生活随笔
收集整理的這篇文章主要介紹了
TransE模型的简单介绍TransE模型的python代码实现
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
模型介紹
TransE模型的基本思想是使head向量和relation向量的和盡可能靠近tail向量。這里我們用L1或L2范數來衡量它們的靠近程度。
損失函數是使用了負抽樣的max-margin函數。
L(y, y') = max(0, margin - y + y')
y是正樣本的得分,y'是負樣本的得分。然后使損失函數值最小化,當這兩個分數之間的差距大于margin的時候就可以了(我們會設置這個值,通常是1)。
由于我們使用距離來表示得分,所以我們在公式中加上一個減號,知識表示的損失函數為:
其中,d是:
這是L1或L2范數。至于如何得到負樣本,則是將head實體或tail實體替換為三元組中的隨機實體。
代碼實現:
具體的代碼和數據集(YAGO、umls、FB15K、WN18)請見Github:
https://github.com/Colinasda/TransE.git
import codecs
import numpy as np
import copy
import time
import random

# Global lookup tables (name -> id string), filled by dataloader().
entities2id = {}
relations2id = {}


def dataloader(file1, file2, file3):
    """Load a knowledge-graph dataset from tab-separated text files.

    file1: training triples, one "head\trelation\ttail" per line.
    file2: entity-to-id map, one "entity\tid" per line.
    file3: relation-to-id map, one "relation\tid" per line.

    Fills the module-level entities2id / relations2id dicts and returns
    (entity_ids, relation_ids, triple_list), where triple_list contains
    [head_id, relation_id, tail_id] entries. Malformed lines are skipped.
    """
    print("load file...")
    entity = []
    relation = []
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        for raw in f1.readlines():
            parts = raw.strip().split('\t')
            if len(parts) != 2:
                continue
            entities2id[parts[0]] = parts[1]
            entity.append(parts[1])
        for raw in f2.readlines():
            parts = raw.strip().split('\t')
            if len(parts) != 2:
                continue
            relations2id[parts[0]] = parts[1]
            relation.append(parts[1])

    triple_list = []
    with codecs.open(file1, 'r') as f:
        for raw in f.readlines():
            fields = raw.strip().split("\t")
            if len(fields) != 3:
                continue
            # Map raw names to their id strings before storing the triple.
            triple_list.append([entities2id[fields[0]],
                                relations2id[fields[1]],
                                entities2id[fields[2]]])

    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity), len(relation), len(triple_list)))
    return entity, relation, triple_list
def norm_l1(h, r, t):
    """L1 dissimilarity of a triple: the sum of |h + r - t| components."""
    residual = h + r - t
    return np.sum(np.fabs(residual))
def norm_l2(h, r, t):
    """Squared L2 dissimilarity of a triple: sum of (h + r - t)**2.

    Note: no square root is taken — this is ||h + r - t||^2, which keeps
    the hinge-loss gradient a simple linear expression.
    """
    diff = h + r - t
    return np.sum(diff * diff)
))class TransE:def __init__(self
, entity
, relation
, triple_list
, embedding_dim
=50, lr
=0.01, margin
=1.0, norm
=1):self
.entities
= entityself
.relations
= relationself
.triples
= triple_listself
.dimension
= embedding_dimself
.learning_rate
= lrself
.margin
= marginself
.norm
= normself
.loss
= 0.0def data_initialise(self
):entityVectorList
= {}relationVectorList
= {}for entity
in self
.entities
:entity_vector
= np
.random
.uniform
(-6.0 / np
.sqrt
(self
.dimension
), 6.0 / np
.sqrt
(self
.dimension
),self
.dimension
)entityVectorList
[entity
] = entity_vector
for relation
in self
.relations
:relation_vector
= np
.random
.uniform
(-6.0 / np
.sqrt
(self
.dimension
), 6.0 / np
.sqrt
(self
.dimension
),self
.dimension
)relation_vector
= self
.normalization
(relation_vector
)relationVectorList
[relation
] = relation_vectorself
.entities
= entityVectorListself
.relations
= relationVectorList
def normalization(self
, vector
):return vector
/ np
.linalg
.norm
(vector
)def training_run(self
, epochs
=1, nbatches
=100, out_file_title
= ''):batch_size
= int(len(self
.triples
) / nbatches
)print("batch size: ", batch_size
)for epoch
in range(epochs
):start
= time
.time
()self
.loss
= 0.0for entity
in self
.entities
.keys
():self
.entities
[entity
] = self
.normalization
(self
.entities
[entity
]);for batch
in range(nbatches
):batch_samples
= random
.sample
(self
.triples
, batch_size
)Tbatch
= []for sample
in batch_samples
:corrupted_sample
= copy
.deepcopy
(sample
)pr
= np
.random
.random
(1)[0]if pr
> 0.5:corrupted_sample
[0] = random
.sample
(self
.entities
.keys
(), 1)[0]while corrupted_sample
[0] == sample
[0]:corrupted_sample
[0] = random
.sample
(self
.entities
.keys
(), 1)[0]else:corrupted_sample
[2] = random
.sample
(self
.entities
.keys
(), 1)[0]while corrupted_sample
[2] == sample
[2]:corrupted_sample
[2] = random
.sample
(self
.entities
.keys
(), 1)[0]if (sample
, corrupted_sample
) not in Tbatch
:Tbatch
.append
((sample
, corrupted_sample
))self
.update_triple_embedding
(Tbatch
)end
= time
.time
()print("epoch: ", epoch
, "cost time: %s" % (round((end
- start
), 3)))print("running loss: ", self
.loss
)with codecs
.open(out_file_title
+"TransE_entity_" + str(self
.dimension
) + "dim_batch" + str(batch_size
), "w") as f1
:for e
in self
.entities
.keys
():f1
.write
(str(list(self
.entities
[e
])))f1
.write
("\n")with codecs
.open(out_file_title
+"TransE_relation_" + str(self
.dimension
) + "dim_batch" + str(batch_size
), "w") as f2
:for r
in self
.relations
.keys
():f2
.write
(str(list(self
.relations
[r
])))f2
.write
("\n")def update_triple_embedding(self
, Tbatch
):copy_entity
= copy
.deepcopy
(self
.entities
)copy_relation
= copy
.deepcopy
(self
.relations
)for correct_sample
, corrupted_sample
in Tbatch
:correct_copy_head
= copy_entity
[correct_sample
[0]]correct_copy_tail
= copy_entity
[correct_sample
[2]]relation_copy
= copy_relation
[correct_sample
[1]]corrupted_copy_head
= copy_entity
[corrupted_sample
[0]]corrupted_copy_tail
= copy_entity
[corrupted_sample
[2]]correct_head
= self
.entities
[correct_sample
[0]]correct_tail
= self
.entities
[correct_sample
[2]]relation
= self
.relations
[correct_sample
[1]]corrupted_head
= self
.entities
[corrupted_sample
[0]]corrupted_tail
= self
.entities
[corrupted_sample
[2]]if self
.norm
== 1:correct_distance
= norm_l1
(correct_head
, relation
, correct_tail
)corrupted_distance
= norm_l1
(corrupted_head
, relation
, corrupted_tail
)else:correct_distance
= norm_l2
(correct_head
, relation
, correct_tail
)corrupted_distance
= norm_l2
(corrupted_head
, relation
, corrupted_tail
)loss
= self
.margin
+ correct_distance
- corrupted_distance
if loss
> 0:self
.loss
+= loss
print(loss
)correct_gradient
= 2 * (correct_head
+ relation
- correct_tail
)corrupted_gradient
= 2 * (corrupted_head
+ relation
- corrupted_tail
)if self
.norm
== 1:for i
in range(len(correct_gradient
)):if correct_gradient
[i
] > 0:correct_gradient
[i
] = 1else:correct_gradient
[i
] = -1if corrupted_gradient
[i
] > 0:corrupted_gradient
[i
] = 1else:corrupted_gradient
[i
] = -1correct_copy_head
-= self
.learning_rate
* correct_gradientrelation_copy
-= self
.learning_rate
* correct_gradientcorrect_copy_tail
-= -1 * self
.learning_rate
* correct_gradientrelation_copy
-= -1 * self
.learning_rate
* corrupted_gradient
if correct_sample
[0] == corrupted_sample
[0]:correct_copy_head
-= -1 * self
.learning_rate
* corrupted_gradientcorrupted_copy_tail
-= self
.learning_rate
* corrupted_gradient
elif correct_sample
[2] == corrupted_sample
[2]:corrupted_copy_head
-= -1 * self
.learning_rate
* corrupted_gradientcorrect_copy_tail
-= self
.learning_rate
* corrupted_gradientcopy_entity
[correct_sample
[0]] = self
.normalization
(correct_copy_head
)copy_entity
[correct_sample
[2]] = self
.normalization
(correct_copy_tail
)if correct_sample
[0] == corrupted_sample
[0]:copy_entity
[corrupted_sample
[2]] = self
.normalization
(corrupted_copy_tail
)elif correct_sample
[2] == corrupted_sample
[2]:copy_entity
[corrupted_sample
[0]] = self
.normalization
(corrupted_copy_head
)copy_relation
[correct_sample
[1]] = relation_copyself
.entities
= copy_entityself
.relations
= copy_relation
if __name__ == '__main__':
    # Dataset files: training triples plus the entity/relation id maps.
    train_path = "/umls/train.txt"
    entity_path = "/umls/entity2id.txt"
    relation_path = "/umls/relation2id.txt"

    entity_set, relation_set, triple_list = dataloader(train_path, entity_path, relation_path)

    # Train a 30-dimensional model with squared-L2 scoring and dump the
    # learned embeddings with a "umls_" filename prefix.
    model = TransE(entity_set, relation_set, triple_list,
                   embedding_dim=30, lr=0.01, margin=1.0, norm=2)
    model.data_initialise()
    model.training_run(out_file_title="umls_")
總結
以上是生活随笔為你收集整理的TransE模型的简单介绍TransE模型的python代码实现的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。