日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當(dāng)前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

知识表示学习 TransE 代码逻辑梳理 超详细解析

發(fā)布時間:2023/12/2 编程问答 26 豆豆
生活随笔 收集整理的這篇文章主要介紹了 知识表示学习 TransE 代码逻辑梳理 超详细解析 小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.

知識表示學(xué)習(xí)

網(wǎng)絡(luò)上已經(jīng)存在了大量知識庫(KBs),比如OpenCyc,WordNet,Freebase,Dbpedia等等。

這些知識庫是為了各種各樣的目的建立的,因此很難用到其他系統(tǒng)上面。為了發(fā)揮知識庫的圖(graph)性,也為了得到統(tǒng)計學(xué)習(xí)(包括機器學(xué)習(xí)和深度學(xué)習(xí))的優(yōu)勢,我們需要將知識庫嵌入(embedding)到一個低維空間里(比如10、20、50維)。我們都知道,獲得了向量后,就可以運用各種數(shù)學(xué)工具進行分析。它為許多知識獲取任務(wù)和下游應(yīng)用鋪平了道路。
總的來說,廢話這么多,所謂知識表示學(xué)習(xí),就是將知識庫給映射成向量,同時滿足屬于同一個三元組的(h,t,l)滿足h+l≈t,而不是一個三元組的不滿足這個條件。

TransE思路

TranE是一篇Bordes等人2013年發(fā)表在NIPS上的文章提出的算法。它的提出,是為了解決多關(guān)系數(shù)據(jù)(multi-relational data)的處理問題。我們現(xiàn)在有很多很多的知識庫數(shù)據(jù)knowledge bases (KBs),比如Freebase、 Google Knowledge Graph 、 GeneOntology等等。
TransE的直觀含義,就是TransE基于實體和關(guān)系的分布式向量表示,將每個三元組實例(head,relation,tail)中的關(guān)系relation看做從實體head到實體tail的翻譯(其實我一直很納悶為什么叫做translating,其實就是向量相加),通過不斷調(diào)整h、r和t(head、relation和tail的向量),使(h + r) 盡可能與 t 相等,即 h + r = t。

損失函數(shù)是

TransE代碼邏輯梳理

首先注明,該代碼不是出自我手,但由于最近需要使用并修改TransE,故從github上找到一個還不錯的TransE實現(xiàn),對其進行閱讀,并梳理其邏輯,為后續(xù)工作做好鋪墊。貼上其github鏈接,感謝前人辛苦付出。https://github.com/wuxiyu/transE/blob/master/tranE.py
下面對其代碼進行分析。
首先,這里將整個代碼封裝成了一個類,該類的構(gòu)造方法(由于平常用的語言是java,python只當(dāng)做工具語言,沒有系統(tǒng)學(xué)過語法,所以用詞可能不當(dāng),見諒)中需要的參數(shù)如下所示:

:param entityList: 實體列表,讀取文本文件,實體+id:param relationList: 關(guān)系列表,讀取文本文件,關(guān)系+id:param tripleList: 三元組列表,讀取文本文件,實體+實體+關(guān)系:param margin: 表示正負(fù)樣本之間的間距,是一個超參數(shù),也就是公式中Loss里的γ:param learingRate: 學(xué)習(xí)率,其實就是梯度下降中的步長:param dim: 向量維度,即h,t,l向量的維度是1*dim,因為最終我們所有的實體和關(guān)系都是要表示為向量:param L1: 距離公式采用矩陣1范數(shù)還是矩陣2范數(shù)

首先,我們將目光放到main方法,從main方法開始整個TransE的旅程。

dirEntity = "C:\\data\\entity2id.txt"entityIdNum, entityList = openDetailsAndId(dirEntity)dirRelation = "C:\\data\\relation2id.txt"relationIdNum, relationList = openDetailsAndId(dirRelation)dirTrain = "C:\\data\\train.txt"tripleNum, tripleList = openTrain(dirTrain)print("打開TransE")transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100)print("TranE初始化")transE.initialize()transE.transE(15000)transE.writeRelationVector("c:\\relationVector.txt")transE.writeEntilyVector("c:\\entityVector.txt")

首先是通過三個Open方法分別獲取實體數(shù)量和實體列表、關(guān)系總數(shù)量和關(guān)系列表、三元組總數(shù)量和三元組列表。獲取需要的數(shù)據(jù)。
例如,其中entityList是一個list,其樣式就為[05451384,04958634,00620424,....];
而relationList樣式為["_member_of_domain_topic","_member_meronym"...];
而tripleList例如[(03964744,04371774,_hyponym), (....)....],其中全是三元組,都是(h,t,l)的格式。
至于那些Num們,都只是用于計數(shù)?并沒發(fā)現(xiàn)用在哪里,也不用管

然后就是實例化TransE這個類了,將實體列表,關(guān)系列表,和三元組列表放進去,設(shè)置間距γ為1(這個是超參數(shù),可以調(diào)),然后對于輸出向量,其維度設(shè)為100(這個也可以自己指定)。

之后調(diào)用transE的initialize()方法,進行初始化。這里初始化具體做了什么呢?答曰初始化向量,構(gòu)建字典集合,分別來裝實體向量們和關(guān)系向量們。那么問題就來了,這個向量如何生成呢,之前我們手里只有05451384這串?dāng)?shù)字來代表實體,但是,并沒有向量啊。這里采用的方式就是···隨機生成,對于個100維的向量,隨機生成它,方式為每一個數(shù)字都是在-6/(dim**0.5), 6/(dim**0.5)之間隨機生成,然后構(gòu)成一個100個元素的列表,即代表這個實體的向量,同時,將這個實體和其對應(yīng)的隨機生成的向量放入新創(chuàng)建的字典entityVectorList中去,同理對于關(guān)系也是如此操作。當(dāng)然,在向量生成之后對其做一個歸一化,保證它是單位向量,做法就是每個元素除以元素總和的平方和的開平方,具體見norm方法,這個很簡單。

entityVectorList = {}relationVectorList = {}for entity in self.entityList:n = 0entityVector = []while n < self.dim:ram = init(self.dim)entityVector.append(ram) #注意到這里的ram和entity是毫無關(guān)系的,是一個隨機的值,所以這里append之后,就是一個dim個元素的列表n += 1entityVector = norm(entityVector)#歸一化entityVectorList[entity] = entityVector

至此,我們便為每個關(guān)系和實體生成了一個向量,向量是一個100維的列表。
然后我們將entityList和relationList賦值成這兩個字典,也就是我們最初的entityList是列表,而經(jīng)過初始化之后卻變成了字典,字典的樣式為{實體名:對應(yīng)向量,…}

之后,下一步就是進行訓(xùn)練了。調(diào)用transE的transE()方法,其中輸入的15000意為迭代的次數(shù)。

for cycleIndex in range(cI):#迭代cI次Sbatch = self.getSample(150) #隨機獲取150個三元組Tbatch = []#元組對(原三元組,打碎的三元組)的列表 :[((h,r,t),(h',r,t'))]for sbatch in Sbatch:#遍歷獲取到的元組,并獲取它們的打碎三元組,從而獲得<=150個元組對(防止重復(fù))tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch)) #將sbatch傳入,獲取打碎的三元組,然后構(gòu)成一個元組對if(tripletWithCorruptedTriplet not in Tbatch):Tbatch.append(tripletWithCorruptedTriplet)self.update(Tbatch)#對整個集合進行更新if cycleIndex % 100 == 0:print("第%d次循環(huán)"%cycleIndex)print(self.loss)self.writeRelationVector("c:\\relationVector.txt")self.writeEntilyVector("c:\\entityVector.txt")self.loss = 0

這里cI參數(shù)就是迭代次數(shù)。
首先是調(diào)用getSample()方法,該方法作用為在tripleList中隨機選取size個三元組并返回。所以這里的Sbatch就是隨機獲取的150個三元組。
然后Tbatch是一個新創(chuàng)建的列表,用于存儲元組(元組是tuple,是python中一種數(shù)據(jù)結(jié)構(gòu),而三元組是知識圖譜的一種結(jié)構(gòu),不要搞亂了),其中的樣式為[((h,r,t),(h',r,t'))]。
下面就是對Sbatch進行遍歷,遍歷每一個三元組,調(diào)用getCorruptedTriplet()方法來獲取某個三元組的打碎的三元組,也就是在上面算法中提到的,對一個三元組,我們假定它是h+l=t的,此時我們創(chuàng)建一個范例,一個絕對不滿足假設(shè)的,如何創(chuàng)建呢,任意用別的h或t來替換掉我們這里的h或t,從而得到一個錯誤的三元組,即打碎的三元組(我也不知道為啥叫打碎,不過挺有意思哈哈哈)。將打碎的三元組和正確的三元組放在一起組成一個新的元組,然后將其放入Tbatch列表中,當(dāng)然這里有個去重的判斷,很簡單,就不說了哈。
下面的操作就是最重要的了,進行更新。
首先,要明確,這里的更新,只是針對我們隨機選出來的150個三元組進行更新。然后,更新什么呢?當(dāng)然是更新它們的向量,所以假設(shè)我們的h,t都互不相同,那么這里最多也就更新了300個實體的向量,(關(guān)系因為數(shù)量肯定沒那么多,就不舉例了)。然后更新的方式是什么,那就是通過梯度下降法來求得損失函數(shù)的最小值,從而獲得一個最優(yōu)的向量們。

好,下面我們來看這個更新操作,這里是調(diào)用update()方法,將剛才的Tbatch傳入。
首先在該方法的開始,進行了兩次拷貝,將實體列表(其實是實體-向量字典)和關(guān)系列表分別進行拷貝,目的是為了之后更新,不相互影響。然后關(guān)于deepcopy和copy的區(qū)別大家可以去查一下,簡單來說就是前者copy的更徹底,列表或字典中的每個元素都單獨拷貝了一份。
然后便是遍歷這里的Tbatch,對每個元組進行操作。
首先是前面是一長串的賦值操作,選其中一個來說明。

headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]

首先我們知道tripletWithCorruptedTriplet的格式是這樣的[((h,r,t),(h',r,t'))],那[0][0]就是獲取其中的h實體,然后根據(jù)h實體在entityList字典中獲取其對應(yīng)的向量。如此便是,其余也皆是同理。
然后根據(jù)L1參數(shù)是否為true來使用矩陣1范數(shù)或矩陣2范數(shù),因為不同范數(shù)它的梯度是不一樣的。
我們接下來矩陣2范數(shù)即L1==false來進行說明。此時進行計算Loss損失函數(shù)的值,根據(jù)公式γ+d(h+l,t)?d(h′+l,t′)\gamma+d(h+l,t)-d(h'+l,t')γ+d(h+l,t)?d(h+l,t)來計算,當(dāng)然這里的d(h+l,t)d(h+l,t)d(h+l,t)要進行展開,就是普通的距離公式,展開之后的Loss函數(shù)為γ+(h+l?t)2?(h′+l?t′)2\gamma+(h+l-t)^{2}-(h'+l-t')^{2}γ+(h+l?t)2?(h+l?t)2,等一下,是不是主要到這里和之前說的有些不同,對的,這里沒有求和符號,因為這里相當(dāng)于是把總的Loss給分開算的,所以沒有求和符號了。累加起來便有。
然后當(dāng)這個損失函數(shù)的值>0時,才進行更新,否則不進行更新。這里解釋一下為什么這么操作。如此操作的原因在于我們喜歡正確的三元組的向量們滿足h+l≈t,而打碎的三元組不滿足,則正確三元組距離應(yīng)該接近于0,而錯誤的應(yīng)為一個不小的正值(因為是矩陣2范數(shù)),然后此時必然有損失函數(shù)值e<0的情況。當(dāng)然,你也會說那假如兩個值都不小,剛好前者小于后者呢,這種情況少,且沒必要要求這么高,畢竟可以近似,同時這是算法層級的問題,這里不再討論。
當(dāng)e>0時,我們進行更新,這里更新的操作就是一個很簡單的梯度下降方法。下面來介紹一下。首先損失函數(shù)Loss是γ+(h+l?t)2?(h′+l?t′)2\gamma+(h+l-t)^{2}-(h'+l-t')^{2}γ+(h+l?t)2?(h+l?t)2,我們對其h進行求導(dǎo)得其梯度,則其結(jié)果是??h=2(h+l?t)\frac{\partial }{\partial h} = 2(h+l-t)?h??=2(h+l?t),則h更新為h?=h?u???h=h?u?2?(h+l?t)=h+u?2?(t?h?l)h^{*}=h-u*\frac{\partial }{\partial h}=h-u*2*(h+l-t)=h+u*2*(t-h-l)h?=h?u??h??=h?u?2?(h+l?t)=h+u?2?(t?h?l),這里的u是梯度下降的步長,也就是上面提到的學(xué)習(xí)率,同理,t的更新也是一樣,t?=t?u?2?(t?h?l)t^{*}=t-u*2*(t-h-l)t?=t?u?2?(t?h?l),然后同理l也是一樣l?=l+u?2?(t?h?l)?u?2?(t′?h′?l)l^{*}=l+u*2*(t-h-l)-u*2*(t'-h'-l)l?=l+u?2?(t?h?l)?u?2?(t?h?l)。
如此,進行更新,然后進行歸一化,最終更新總的entityList和relationList。

至此,更新過程結(jié)束,至于后面的向量寫入文件這里就不贅述了。

完整代碼

這里代碼我都加上了較為詳細的注釋,可以結(jié)合上面的代碼梳理進行理解。

from random import uniform, sample from numpy import * from copy import deepcopyclass TransE:def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):''':param entityList: 實體列表,讀取文本文件,實體+id:param relationList: 關(guān)系列表,讀取文本文件,關(guān)系+id:param tripleList: 三元組列表,讀取文本文件,實體+實體+關(guān)系:param margin: gamma,目標(biāo)函數(shù)的常數(shù):param learingRate: 學(xué)習(xí)率:param dim: 向量維度,也就是h,t,l向量的維度是1*dim:param L1: 距離公式'''self.margin = marginself.learingRate = learingRateself.dim = dim#向量維度self.entityList = entityList#一開始,entityList是entity的list;初始化后,變?yōu)樽值?#xff0c;key是entity,values是其向量(使用narray)。self.relationList = relationList#理由同上self.tripleList = tripleList#理由同上self.loss = 0self.L1 = L1def initialize(self):'''初始化向量'''entityVectorList = {}relationVectorList = {}for entity in self.entityList:n = 0entityVector = []while n < self.dim:ram = init(self.dim)#初始化的范圍entityVector.append(ram) #注意到這里的ram和entity是毫無關(guān)系的,是一個隨機的值,所以這里append之后,就是一個dim個元素的列表n += 1entityVector = norm(entityVector)#歸一化entityVectorList[entity] = entityVectorprint("entityVector初始化完成,數(shù)量是%d"%len(entityVectorList))for relation in self. relationList:n = 0relationVector = []while n < self.dim:ram = init(self.dim)#初始化的范圍relationVector.append(ram)n += 1relationVector = norm(relationVector)#歸一化relationVectorList[relation] = relationVectorprint("relationVectorList初始化完成,數(shù)量是%d"%len(relationVectorList))self.entityList = entityVectorListself.relationList = relationVectorListdef transE(self, cI = 20):print("訓(xùn)練開始")for cycleIndex in range(cI):#迭代cI次Sbatch = self.getSample(150) #隨機獲取150個三元組Tbatch = []#元組對(原三元組,打碎的三元組)的列表 :{((h,r,t),(h',r,t'))}for sbatch in Sbatch:#遍歷獲取到的元組,并獲取它們的打碎三元組,從而獲得<=150個元組對(防止重復(fù))tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch)) #將sbatch傳入,獲取打碎的三元組,然后構(gòu)成一個元組對if(tripletWithCorruptedTriplet not in Tbatch):Tbatch.append(tripletWithCorruptedTriplet)self.update(Tbatch)#對整個集合進行更新if cycleIndex % 100 == 0:print("第%d次循環(huán)"%cycleIndex)print(self.loss)self.writeRelationVector("c:\\relationVector.txt")self.writeEntilyVector("c:\\entityVector.txt")self.loss = 0def getSample(self, size):'''隨機選取部分三元關(guān)系 sbatch:param size::return:'''return sample(self.tripleList, size) #從tripleList中隨機獲取size個元素def getCorruptedTriplet(self, triplet):'''training triplets with either the head or tail replaced by a random entity (but not both at the same time)隨機替換三元組的實體,h和t中任意一個被替換,但不同時替換。也就是構(gòu)建損壞的三元組集合:param triplet::return corruptedTriplet:'''i = uniform(-1, 1)if i < 0:#小于0,打壞三元組的第一項while True:entityTemp = sample(self.entityList.keys(), 1)[0]if entityTemp != triplet[0]:breakcorruptedTriplet = (entityTemp, triplet[1], triplet[2])else:#大于等于0,打壞三元組的第二項while True:entityTemp = sample(self.entityList.keys(), 1)[0]if entityTemp != triplet[1]:breakcorruptedTriplet = (triplet[0], entityTemp, triplet[2])return corruptedTripletdef update(self, Tbatch):'''進行更新,更新的過程就是一個梯度下降:param Tbatch::return:'''copyEntityList = deepcopy(self.entityList) #copy和deepcopy的區(qū)別在于,copy只拷貝整體,若局部改變,則拷貝整體的局部也改變,而deepcopy則全部拷貝過去copyRelationList = deepcopy(self.relationList)for tripletWithCorruptedTriplet in Tbatch:#遍歷整個元組,最多迭代150次# 這里的索引很好理解((h,t,l)(h',t',l)) 但是copyEntityList[h]# 懂了,這里EntityList是類似于字典的,有id與向量這兩個東西,所以是輸入id,獲取向量headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元組和打碎的三元組的元組tupletailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]]relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]]headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]]tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]]#下面的也是一模一樣,感覺只是為了備份一份,進行比較headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元組和打碎的三元組的元組tupletailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]]relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]]headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]]tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]]if self.L1:#這L1啥意思···哦是L1范數(shù)distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)else:#否則L2范數(shù)distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)eg = self.margin + distTriplet - distCorruptedTriplet #損失函數(shù) 就跟論文上公式是一樣的if eg > 0: #[function]+ 是一個取正值的函數(shù) 似乎是只有大于0時才進行更新,想一下,也確實,因為前一個距離應(yīng)該為0,后一個不為0,然后,0-正<0則不用改,正-正>則需要改self.loss += egif self.L1:#這個學(xué)習(xí)率有點懵tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)tempPositiveL1 = []tempNegtativeL1 = []for i in range(self.dim):#不知道有沒有pythonic的寫法(比如列表推倒或者numpy的函數(shù))?if tempPositive[i] >= 0:tempPositiveL1.append(1)else:tempPositiveL1.append(-1)if tempNegtative[i] >= 0:tempNegtativeL1.append(1)else:tempNegtativeL1.append(-1)tempPositive = array(tempPositiveL1) tempNegtative = array(tempNegtativeL1)else:#這里學(xué)習(xí)率就是y?對,應(yīng)該這里的學(xué)習(xí)率就是梯度下降中的步長#然后括號里是t-h-ltempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)#進行更新headEntityVector = headEntityVector + tempPositive #h* = h + 增量tailEntityVector = tailEntityVector - tempPositive #t* = t - 增量relationVector = relationVector + tempPositive - tempNegtative #l* = l +y*2(t-h-l) -y*2(t'-h'-l)headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative #同理tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative #同理#只歸一化這幾個剛更新的向量,而不是按原論文那些一口氣全更新了copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector)copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector)copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector)copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet)copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet)self.entityList = copyEntityList #進行更新self.relationList = copyRelationListdef writeEntilyVector(self, dir):print("寫入實體")entityVectorFile = open(dir, 'w')for entity in self.entityList.keys():entityVectorFile.write(entity+"\t")entityVectorFile.write(str(self.entityList[entity].tolist()))entityVectorFile.write("\n")entityVectorFile.close()def writeRelationVector(self, dir):print("寫入關(guān)系")relationVectorFile = open(dir, 'w')for relation in self.relationList.keys():relationVectorFile.write(relation + "\t")relationVectorFile.write(str(self.relationList[relation].tolist()))relationVectorFile.write("\n")relationVectorFile.close()def init(dim):'''向量初始化,隨機生成值:param dim: 維度:return:'''return uniform(-6/(dim**0.5), 6/(dim**0.5)) #uniform(a, b)#隨機生成a,b之間的數(shù),左閉右開def distanceL1(h, t ,r):s = h + r - tsum = fabs(s).sum()return sumdef distanceL2(h, t, r):'''這里是對向量進行操作的,所以有個sum:param h: 這里的都是向量:param t::param r::return:'''s = h + r - tsum = (s*s).sum()return sumdef norm(list):'''歸一化:param 向量:return: 向量/向量的能量'''var = linalg.norm(list)i = 0while i < len(list):list[i] = list[i]/vari += 1return array(list)def openDetailsAndId(dir,sp="\t"):idNum = 0list = []with open(dir) as file:lines = file.readlines()for line in lines:DetailsAndId = line.strip().split(sp)list.append(DetailsAndId[0])idNum += 1return idNum, listdef openTrain(dir,sp="\t"):num = 0list = []with open(dir) as file:lines = file.readlines()for line in lines:triple = line.strip().split(sp)if(len(triple)<3):continuelist.append(tuple(triple))num += 1return num, listif __name__ == '__main__':dirEntity = "C:\\data\\entity2id.txt"entityIdNum, entityList = openDetailsAndId(dirEntity)dirRelation = "C:\\data\\relation2id.txt"relationIdNum, relationList = openDetailsAndId(dirRelation)dirTrain = "C:\\data\\train.txt"tripleNum, tripleList = openTrain(dirTrain)print("打開TransE")transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100)print("TranE初始化")transE.initialize()transE.transE(15000)transE.writeRelationVector("c:\\relationVector.txt")transE.writeEntilyVector("c:\\entityVector.txt")

參考資料

https://blog.csdn.net/u011274209/article/details/50991385
https://blog.csdn.net/jiayalu/article/details/100543909
https://github.com/wuxiyu/transE/blob/master/tranE.py

總結(jié)

以上是生活随笔為你收集整理的知识表示学习 TransE 代码逻辑梳理 超详细解析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網(wǎng)站內(nèi)容還不錯,歡迎將生活随笔推薦給好友。