知识图谱源码详解【八】__init__.py
生活随笔收集整理的这篇文章主要介绍了知识图谱源码详解【八】__init__.py，小编觉得挺不错的，现在分享给大家，帮大家做个参考。
import torch
from src.model.DKN.KCNN import KCNN
from src.model.DKN.attention import Attention
from src.model.general.click_predictor.DNN import DNNClickPredictor


# DKN wires the previously defined components (KCNN news encoder,
# candidate-aware attention, DNN click predictor) into one end-to-end model.
class DKN(torch.nn.Module):
    """Deep knowledge-aware network.

    Input 1 + K candidate news and a list of user clicked news, produce the
    click probability.
    """

    def __init__(self,
                 config,
                 pretrained_word_embedding=None,
                 pretrained_entity_embedding=None,
                 pretrained_context_embedding=None):
        """Build the news encoder, attention module and click predictor.

        Args:
            config: model configuration. This class reads ``window_sizes``
                and ``num_filters`` from it and passes it on to ``KCNN`` and
                ``Attention``.
            pretrained_word_embedding: optional, forwarded to ``KCNN``.
            pretrained_entity_embedding: optional, forwarded to ``KCNN``.
            pretrained_context_embedding: optional, forwarded to ``KCNN``.
        """
        super().__init__()
        self.config = config
        self.kcnn = KCNN(config, pretrained_word_embedding,
                         pretrained_entity_embedding,
                         pretrained_context_embedding)
        self.attention = Attention(config)
        # Input size is twice the news-vector size
        # (len(window_sizes) * num_filters). Presumably DNNClickPredictor
        # concatenates the candidate and user vectors -- TODO confirm.
        self.click_predictor = DNNClickPredictor(
            len(self.config.window_sizes) * 2 * self.config.num_filters)

    def forward(self, candidate_news, clicked_news):
        """Score every candidate news item against the user's click history.

        Args:
            candidate_news:
                [
                    {
                        "title": batch_size * num_words_title,
                        "title_entities": batch_size * num_words_title
                    } * (1 + K)
                ]
            clicked_news:
                [
                    {
                        "title": batch_size * num_words_title,
                        "title_entities": batch_size * num_words_title
                    } * num_clicked_news_a_user
                ]
        Returns:
            click_probability: batch_size, 1 + K
        """
        # batch_size, 1 + K, len(window_sizes) * num_filters
        candidate_news_vector = torch.stack(
            [self.kcnn(x) for x in candidate_news], dim=1)
        # batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
        clicked_news_vector = torch.stack(
            [self.kcnn(x) for x in clicked_news], dim=1)
        # The user vector depends on the candidate: attend over the clicked
        # news once per candidate (iterating over the 1 + K axis).
        # batch_size, 1 + K, len(window_sizes) * num_filters
        user_vector = torch.stack([
            self.attention(x, clicked_news_vector)
            for x in candidate_news_vector.transpose(0, 1)
        ], dim=1)
        size = candidate_news_vector.size()
        # Flatten (batch, candidate) into one axis for the predictor, then
        # restore it.
        # batch_size, 1 + K
        click_probability = self.click_predictor(
            candidate_news_vector.view(size[0] * size[1], size[2]),
            user_vector.view(size[0] * size[1],
                             size[2])).view(size[0], size[1])
        return click_probability

    def get_news_vector(self, news):
        """Encode one batch of news with KCNN.

        Args:
            news:
                {
                    "title": batch_size * num_words_title,
                    "title_entities": batch_size * num_words_title
                }
        Returns:
            (shape) batch_size, len(window_sizes) * num_filters
        """
        # batch_size, len(window_sizes) * num_filters
        return self.kcnn(news)

    def get_user_vector(self, clicked_news_vector):
        """Return the clicked news vectors unchanged.

        In DKN the user representation depends on the candidate, so the
        attention step is deferred to ``get_prediction``.

        Args:
            clicked_news_vector: batch_size, num_clicked_news_a_user,
                len(window_sizes) * num_filters
        Returns:
            (shape) batch_size, num_clicked_news_a_user,
                len(window_sizes) * num_filters
        """
        # batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
        return clicked_news_vector

    def get_prediction(self, candidate_news_vector, clicked_news_vector):
        """Score candidate news vectors against one user's clicked news.

        Args:
            candidate_news_vector: candidate_size,
                len(window_sizes) * num_filters
            clicked_news_vector: num_clicked_news_a_user,
                len(window_sizes) * num_filters
        Returns:
            click_probability: candidate_size
        """
        # expand() adds a leading candidate_size axis (a view, no copy) so the
        # same history is attended to for every candidate.
        # user_vector: candidate_size, len(window_sizes) * num_filters
        user_vector = self.attention(
            candidate_news_vector,
            clicked_news_vector.expand(candidate_news_vector.size(0), -1, -1))
        # candidate_size
        click_probability = self.click_predictor(candidate_news_vector,
                                                 user_vector)
        return click_probability
总结
以上是生活随笔为你收集整理的知识图谱源码详解【八】__init__.py的全部内容，希望文章能够帮你解决所遇到的问题。
- 上一篇: PAT真题乙类1006 换个格式输出整数
- 下一篇: 幻侠修仙服务器维护,幻侠修仙常见问题_幻