Learning Deep Structured Semantic Models for Web Search using Clickthrough Data (DSSM)
Main research problem:
Given a query and a set of documents, return a ranking in which the documents are ordered by how well each one matches the query.
Structure of the paper:
I. Abstract
1. Retrieval based on literal keyword matching often fails; latent semantic models aim to bridge this lexical gap between queries and documents.
2. Given a query and a set of documents from historical click logs, the model is trained to maximize the conditional probability of the clicked documents.
3. A word hashing technique is adopted so that the model scales to large-vocabulary web search.
4. Experiments on real-world web ranking data show that DSSM significantly outperforms the other models.
II. Introduction
1. Latent semantic models such as LSA have inherent limitations, so their quality has plateaued.
2. The explosive growth of the web produces enormous click logs, and models like LSA face storage and scalability challenges.
3. Mapping words into a vector space instead of using one-hot encodings opens a new direction for model design.
III. Related Work
Reviews latent semantic models applied to clickthrough data, and gives a brief introduction to deep learning.
IV. Deep Structured Semantic Models
1. Term vector: a bag-of-words encoding of the raw query/document text.
2. Word hashing: the term vector is re-encoded with letter n-grams.
3. Multi-layer non-linear projection: a multi-layer perceptron that maps the hashed input to a low-dimensional semantic vector.
4. Relevance measured by cosine similarity between the query and document semantic vectors.
5. Posterior probability computed by softmax: the matching probability of each document given the query (see the formulas after these notes).
The word hashing section notes that collisions (two different words sharing the same letter-n-gram representation) are possible, but rare enough to be negligible.
Word hashing mainly gives morphologically similar words similar representations; it is quite robust and greatly alleviates the OOV problem (a small sketch follows below).
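To make this concrete, here is a minimal sketch of letter-trigram word hashing (`letter_trigrams` is just an illustrative helper; the `n_gram` function in the code section below implements the same idea):

def letter_trigrams(word):
    # Wrap the word in boundary markers, then slide a 3-character window over it.
    word = "#" + word.lower() + "#"
    return [word[i:i + 3] for i in range(len(word) - 2)]

print(letter_trigrams("good"))   # ['#go', 'goo', 'ood', 'od#']
print(letter_trigrams("goods"))  # ['#go', 'goo', 'ood', 'ods', 'ds#']

Since "good" and "goods" share three of their trigrams, their hashed representations stay close, which is exactly the morphological robustness noted above.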
Trick: network parameters are initialized from a uniform distribution (Xavier initialization), and training uses SGD.
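For reference, the core formulas from the paper: the relevance of document D to query Q is the cosine of their semantic vectors y_Q and y_D, the posterior over candidate documents is a smoothed softmax over these relevances, and training minimizes the negative log-likelihood of the clicked documents:

R(Q, D) = \cos(y_Q, y_D) = \frac{y_Q^{\top} y_D}{\lVert y_Q \rVert \, \lVert y_D \rVert}

P(D \mid Q) = \frac{\exp(\gamma \, R(Q, D))}{\sum_{D' \in \mathbf{D}} \exp(\gamma \, R(Q, D'))}

L(\Lambda) = -\log \prod_{(Q, D^{+})} P(D^{+} \mid Q)

Here \mathbf{D} is the candidate set containing the clicked document D^{+} plus randomly sampled unclicked ones, and \gamma is a smoothing factor set empirically on a held-out set.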
V. Experiments

Model 9: no word hashing; Model 10: fully-connected layers removed; Model 11: no non-linear activation; Model 12: the complete DSSM.
Removing word hashing hurts the results the most: mapping the bag-of-words encoding directly to a dense vector, without the word-hashing layer as a buffer, loses a lot of information.
VI. Conclusion
Innovations:
1. Uses a deep neural network (DNN) to process both queries and documents.
2. Uses non-linear functions to map text into a semantic space.
3. Uses word hashing to deal with the oversized vocabulary.
4. Uses cosine similarity to measure how well the query matches each of multiple documents.
5. Unlike word2vec, DSSM is trained with supervision.
Key points:
1. Uses clickthrough data to design the document ranking experiments.
2. Uses non-linear activation functions to extract semantic features.
3. Uses word hashing so the model can be applied in large-scale real-world production environments.
VII. Code
Dataset: MRPC (Microsoft Research Paraphrase Corpus, sometimes also called MSRP), an open semantic-similarity dataset released by Microsoft Research. The sentence pairs are drawn from reports of the same news story, and the task is to judge whether the two sentences in a pair are semantically equivalent.
""" 數據處理部分 """# encoding = 'utf-8'vocab = [] file_path = "./MRPC/" files = ['train_data.csv','test_data.csv']def n_gram(word,n=3):s = []word = "#" + word + "#"for i in range(len(word)-(n-1)):s.append(word[i:i+3])return sdef lst_gram(lst,n=3):s = []for word in str(lst).lower().split():s.extend(n_gram(word))return sfor file in files:f = open(file_path + file,encoding='utf-8').readlines()for i in range(1,len(f)):s1,s2 = f[i][2:].strip('\n').split('\t')vocab.extend(lst_gram(s1))vocab.extend(lst_gram(s2))vocab = set(vocab) vocab_list = ['[PAD]','[UNK]'] vocab_list.extend(list(vocab))vocab_file = "save_vocab_mrpc.txt" with open(vocab_file,'w',encoding='utf-8') as f:for slice in vocab_file:f.write(slice)f.write('\n')import numpy as np import pandas as pd import torchdef load_vocab():vocab = open(vocab_file,encoding='utf-8').readlines()slice2idx,idx2slice,cnt = {},{},0for char in vocab:char = char.strip('\n')slice2idx[char] = cntidx2slice[cnt] = charcnt+=1return slice2idx,idx2slicedef padding(text, maxlen=70):pad_text = []for sentence in text:pad_sentence = np.zeros(maxlen).astype('int64')cnt = 0for index in sentence:pad_sentence[cnt] = indexcnt += 1if cnt == maxlen:breakpad_text.append(pad_sentence.tolist())return pad_textdef char_index(text_a,text_b):slice2idx,idx2slice = load_vocab()a_list,b_list = [],[]for a_sentence,b_sentence in zip(text_a,text_b):a,b = [],[]for slice in lst_gram(a_sentence):if slice in slice2idx.keys():a.append(slice2idx[slice])else:a.append(1)for slice in lst_gram(b_sentence):if slice in slice2idx.keys():b.append(slice2idx[slice])else:b.append(1)a_list.append(a)b_list.append(b)a_list = padding(a_list)b_list = padding(b_list)return a_list,b_listdef load_char_data(file_name):import pandas as pddf = pd.read_csv(file_name,sep='\t')text_a = df['#1 string'].valuestext_b = df['#2 string'].valueslabel = df['quality'].valuesa_index,b_index = char_index(text_a,text_b)return np.array(a_index),np.array(b_index),np.array(label)a_index,b_index,label = load_char_data('./MRPC/test_data.csv')?
""" 模型構建部分 """import torch import torch.nn as nn from torch.utils.data import DataLoader,Dataset from torch.autograd import Variableclass DSSM(torch.nn.Module):def __init__(self):super(DSSM,self).__init__()self.embedding = nn.Embedding(CHAR_SIZE,embedding_size)self.linear1 = nn.Linear(embedding_size,256)self.linear2 = nn.Linear(256,128)self.linear3 = nn.Linear(128,64)self.dropout = nn.Dropout(p=0.2)def forward(self,a,b):a = self.embedding(a).sum(1)b = self.embedding(b).sum(1)a = torch.tanh(self.linear1(a))a = self.dropout(a)a = torch.tanh(self.linear2(a))a = self.dropout(a)a = torch.tanh(self.linear3(a))a = self.dropout(a)b = torch.tanh(self.linear1(b))b = self.dropout(b)b = torch.tanh(self.linear2(b))b = self.dropout(b)b = torch.tanh(self.linear3(b))b = self.dropout(b)cosine = torch.cosine_similarity(a,b,dim=1,eps=1e-8)return cosinedef _initialize_weights(self):for m in self.modules():if isinstance(m,nn.Linear):torch.nn.init.xavier_uniform_(m.weight)class MRPCDataset(Dataset):def __init__(self,filepath):self.path = filepathself.a_index,self.b_index,self.label = load_char_data(filepath)def __len__(self):return len(self.a_index)def __getitem__(self,idx):return self.a_index[idx],self.b_index[idx],self.label[idx]?
""" 模型訓練部分 """CHAR_SIZE=10041 embedding_size=300EPOCH=50 BATCH_SIZE=50 LR=0.0005data_root='./MRPC/' train_path=data_root+'train_data.csv' test_path=data_root+'test_data.csv'#1、創建數據集并創立數據載入器 train_data=MRPCDataset(train_path) test_data=MRPCDataset(test_path) train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True) test_loader=DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=True)#2、有gpu用gpu,否則cpu device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") dssm=DSSM().to(device) dssm._initialize_weights()#3、定義優化方式和損失函數 optimizer=torch.optim.Adam(dssm.parameters(),lr=LR) loss_func=nn.CrossEntropyLoss()for epoch in range(EPOCH):for step,(text_a,text_b,label) in enumerate(train_loader):#1、把索引轉化為tensor變量,載入設備,注意轉化成long tensora=Variable(text_a.to(device).long())b=Variable(text_b.to(device).long())l=Variable(torch.LongTensor(label).to(device))#2、計算余弦相似度pos_res=dssm(a,b)neg_res=1-pos_res#3、預測結果傳給lossout=torch.stack([neg_res,pos_res],1).to(device)loss=loss_func(out,l)#4、固定格式optimizer.zero_grad()loss.backward()optimizer.step()if (step+1) % 20 == 0:total=0correct=0for (test_a,test_b,test_l) in test_loader:tst_a=Variable(test_a.to(device).long())tst_b=Variable(test_b.to(device).long())tst_l=Variable(torch.LongTensor(test_l).to(device))pos_res=dssm(tst_a,tst_b)neg_res=1-pos_resout=torch.max(torch.stack([neg_res,pos_res],1).to(device),1)[1]if out.size()==tst_l.size():total+=tst_l.size(0)correct+=(out==tst_l).sum().item()print('[Epoch]:',epoch+1,'訓練loss:',loss.item())print('[Epoch]:',epoch+1,'測試集準確率: ',(correct*1.0/total))torch.save(dssm, './dssm.pkl')?