DGL Tutorial (3): Building Your Own GNN Module
Sometimes, stacking existing GNN modules is not enough. For example, we may want to invent a new way of aggregating neighbor information that takes node importance or edge weights into account.
This section covers:
- DGL's message passing APIs
- Implementing a GraphSAGE convolution module yourself
Message Passing GNNs
DGL follows the message passing paradigm, and many GNN models fit the following framework:

$$m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)$$

$$m_v^{(l)} = \sum_{u\in\mathcal{N}(v)} m_{u\to v}^{(l)}$$

$$h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)$$

DGL calls $M^{(l)}$ the message function, $\sum$ the reduce function, and $U^{(l)}$ the update function. Note that $\sum$ here can represent any aggregation and is not restricted to summation.
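For instance, keeping the same message function while swapping the reduce function changes the aggregation. Here is a minimal sketch using DGL's built-in functions (the toy graph and feature names are illustrative, not part of the tutorial's code):

```python
import dgl
import dgl.function as fn
import torch

# A toy directed graph with edges 0->2, 1->3, 2->0, 3->1.
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([2, 3, 0, 1])))
g.ndata['h'] = torch.randn(4, 5)

# The same message function paired with different built-in reduce functions.
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))    # sum of neighbor features
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))  # mean of neighbor features
g.update_all(fn.copy_u('h', 'm'), fn.max('m', 'h_max'))    # element-wise max
```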
For example, the well-known GraphSAGE (mean aggregator) uses the following formulas:

$$h_{\mathcal{N}(v)}^{k} \leftarrow \mathrm{Average}\left\{h_u^{k-1}, \forall u \in \mathcal{N}(v)\right\}$$

$$h_v^{k} \leftarrow \mathrm{ReLU}\left(W^{k} \cdot \mathrm{CONCAT}\left(h_v^{k-1}, h_{\mathcal{N}(v)}^{k}\right)\right)$$
We can see that message passing is directional: the message a node $u$ sends to a node $v$ is not necessarily the same as the message sent in the opposite direction, from $v$ to $u$.
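A minimal sketch makes this concrete (the two-node graph below is illustrative): update_all sends messages along the edge direction only, from source to destination.

```python
import dgl
import dgl.function as fn
import torch

# A graph with a single directed edge 0 -> 1.
g = dgl.graph((torch.tensor([0]), torch.tensor([1])))
g.ndata['h'] = torch.tensor([[1.0], [2.0]])

g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_N'))
# Node 1 receives node 0's feature; node 0 has no in-edges and gets zeros.
print(g.ndata['h_N'])  # tensor([[0.], [1.]])
```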
Although DGL already provides GraphSAGE as the built-in dgl.nn.SAGEConv, you can still implement it yourself:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'),
                         reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
```

The core of this code is the g.update_all method, whose purpose is to gather and aggregate the features of the surrounding neighbors:
- The message function fn.copy_u('h', 'm') copies the node feature 'h' and sends it as the message 'm' to the node's neighbors.
- The reduce function fn.mean('m', 'h_N') averages all received messages 'm' and saves the result in a new node feature 'h_N'.
- update_all tells DGL to trigger the message and reduce functions on all nodes and edges.
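The built-in functions in dgl.function are the recommended path because DGL can optimize them internally, but conceptually they are shorthand for user-defined functions. A rough UDF equivalent of the two built-ins above (assuming some DGLGraph g whose nodes carry a feature 'h') would look like this:

```python
def message_func(edges):
    # Equivalent of fn.copy_u('h', 'm'): send each source node's 'h' as message 'm'.
    return {'m': edges.src['h']}

def reduce_func(nodes):
    # Equivalent of fn.mean('m', 'h_N'): average the messages in each node's mailbox.
    return {'h_N': nodes.mailbox['m'].mean(dim=1)}

g.update_all(message_func, reduce_func)
```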
You can then stack SAGEConv layers to build a multi-layer GraphSAGE network.
```python
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
```

Training
The code below can be taken directly from the previous tutorial:
```python
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
```

Output:
```
In epoch 0, loss: 1.949, val acc: 0.122 (best 0.122), test acc: 0.130 (best 0.130)
In epoch 5, loss: 1.872, val acc: 0.326 (best 0.326), test acc: 0.347 (best 0.347)
In epoch 10, loss: 1.740, val acc: 0.386 (best 0.386), test acc: 0.424 (best 0.424)
In epoch 15, loss: 1.545, val acc: 0.460 (best 0.460), test acc: 0.495 (best 0.495)
In epoch 20, loss: 1.291, val acc: 0.536 (best 0.536), test acc: 0.575 (best 0.575)
In epoch 25, loss: 0.993, val acc: 0.620 (best 0.620), test acc: 0.653 (best 0.653)
In epoch 30, loss: 0.691, val acc: 0.682 (best 0.682), test acc: 0.690 (best 0.690)
In epoch 35, loss: 0.435, val acc: 0.728 (best 0.728), test acc: 0.721 (best 0.721)
In epoch 40, loss: 0.255, val acc: 0.742 (best 0.742), test acc: 0.747 (best 0.747)
In epoch 45, loss: 0.145, val acc: 0.738 (best 0.742), test acc: 0.751 (best 0.747)
In epoch 50, loss: 0.084, val acc: 0.740 (best 0.742), test acc: 0.756 (best 0.747)
In epoch 55, loss: 0.051, val acc: 0.744 (best 0.746), test acc: 0.759 (best 0.759)
In epoch 60, loss: 0.034, val acc: 0.752 (best 0.752), test acc: 0.762 (best 0.762)
In epoch 65, loss: 0.024, val acc: 0.752 (best 0.752), test acc: 0.765 (best 0.762)
In epoch 70, loss: 0.018, val acc: 0.754 (best 0.754), test acc: 0.769 (best 0.767)
In epoch 75, loss: 0.014, val acc: 0.756 (best 0.756), test acc: 0.772 (best 0.772)
In epoch 80, loss: 0.012, val acc: 0.758 (best 0.758), test acc: 0.770 (best 0.772)
In epoch 85, loss: 0.010, val acc: 0.758 (best 0.758), test acc: 0.769 (best 0.772)
In epoch 90, loss: 0.009, val acc: 0.760 (best 0.760), test acc: 0.772 (best 0.770)
In epoch 95, loss: 0.008, val acc: 0.760 (best 0.760), test acc: 0.773 (best 0.770)
In epoch 100, loss: 0.007, val acc: 0.762 (best 0.762), test acc: 0.770 (best 0.772)
In epoch 105, loss: 0.007, val acc: 0.762 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 110, loss: 0.006, val acc: 0.762 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 115, loss: 0.006, val acc: 0.760 (best 0.762), test acc: 0.770 (best 0.772)
In epoch 120, loss: 0.005, val acc: 0.760 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 125, loss: 0.005, val acc: 0.758 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 130, loss: 0.005, val acc: 0.758 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 135, loss: 0.004, val acc: 0.758 (best 0.762), test acc: 0.768 (best 0.772)
In epoch 140, loss: 0.004, val acc: 0.758 (best 0.762), test acc: 0.768 (best 0.772)
In epoch 145, loss: 0.004, val acc: 0.758 (best 0.762), test acc: 0.768 (best 0.772)
In epoch 150, loss: 0.004, val acc: 0.758 (best 0.762), test acc: 0.768 (best 0.772)
In epoch 155, loss: 0.004, val acc: 0.756 (best 0.762), test acc: 0.769 (best 0.772)
In epoch 160, loss: 0.003, val acc: 0.758 (best 0.762), test acc: 0.771 (best 0.772)
In epoch 165, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
In epoch 170, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.773 (best 0.772)
In epoch 175, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
In epoch 180, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
In epoch 185, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
In epoch 190, loss: 0.003, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
In epoch 195, loss: 0.002, val acc: 0.756 (best 0.762), test acc: 0.772 (best 0.772)
```

Customization
DGL provides many built-in message and reduce functions in the dgl.function package, and you can use them to build new custom convolution modules. For example, the code below implements a SAGEConv variant that aggregates neighbor representations using a weighted average; the edge features stored in edata can also take part in message passing.
```python
class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'),
                         reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
```

Because the graphs in our current dataset have no edge weights, the model below manually assigns an all-one weight to every edge; you can set the weights according to your own situation.
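As one illustration of a non-trivial weighting (a sketch of my own, not part of the tutorial's code), you could weight each edge by the inverse in-degree of its destination node:

```python
# Illustrative non-uniform weights: scale each message by the inverse
# in-degree of its destination node.
deg = g.in_degrees().float().clamp(min=1)  # guard against zero in-degree
src, dst = g.edges()
w = (1.0 / deg[dst]).unsqueeze(1)          # shape (num_edges, 1), one weight per edge
# `w` could then be passed as the third argument to WeightedSAGEConv.forward.
```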
```python
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
```

The output is as follows:
```
In epoch 0, loss: 1.952, val acc: 0.102 (best 0.102), test acc: 0.082 (best 0.082)
In epoch 5, loss: 1.872, val acc: 0.206 (best 0.208), test acc: 0.212 (best 0.194)
In epoch 10, loss: 1.721, val acc: 0.428 (best 0.542), test acc: 0.449 (best 0.561)
In epoch 15, loss: 1.498, val acc: 0.424 (best 0.542), test acc: 0.439 (best 0.561)
In epoch 20, loss: 1.216, val acc: 0.552 (best 0.552), test acc: 0.560 (best 0.560)
In epoch 25, loss: 0.906, val acc: 0.656 (best 0.656), test acc: 0.655 (best 0.655)
In epoch 30, loss: 0.618, val acc: 0.696 (best 0.696), test acc: 0.717 (best 0.717)
In epoch 35, loss: 0.390, val acc: 0.722 (best 0.722), test acc: 0.741 (best 0.741)
In epoch 40, loss: 0.235, val acc: 0.718 (best 0.722), test acc: 0.748 (best 0.741)
In epoch 45, loss: 0.141, val acc: 0.722 (best 0.722), test acc: 0.755 (best 0.741)
In epoch 50, loss: 0.087, val acc: 0.730 (best 0.730), test acc: 0.757 (best 0.756)
In epoch 55, loss: 0.056, val acc: 0.728 (best 0.730), test acc: 0.761 (best 0.756)
In epoch 60, loss: 0.038, val acc: 0.728 (best 0.730), test acc: 0.758 (best 0.756)
In epoch 65, loss: 0.028, val acc: 0.728 (best 0.730), test acc: 0.756 (best 0.756)
In epoch 70, loss: 0.021, val acc: 0.732 (best 0.732), test acc: 0.756 (best 0.756)
In epoch 75, loss: 0.017, val acc: 0.732 (best 0.732), test acc: 0.756 (best 0.756)
In epoch 80, loss: 0.015, val acc: 0.732 (best 0.732), test acc: 0.753 (best 0.756)
In epoch 85, loss: 0.013, val acc: 0.732 (best 0.732), test acc: 0.753 (best 0.756)
In epoch 90, loss: 0.011, val acc: 0.732 (best 0.734), test acc: 0.754 (best 0.754)
In epoch 95, loss: 0.010, val acc: 0.732 (best 0.734), test acc: 0.754 (best 0.754)
In epoch 100, loss: 0.009, val acc: 0.734 (best 0.734), test acc: 0.753 (best 0.754)
In epoch 105, loss: 0.008, val acc: 0.732 (best 0.734), test acc: 0.754 (best 0.754)
In epoch 110, loss: 0.008, val acc: 0.732 (best 0.734), test acc: 0.754 (best 0.754)
In epoch 115, loss: 0.007, val acc: 0.736 (best 0.736), test acc: 0.753 (best 0.753)
In epoch 120, loss: 0.006, val acc: 0.736 (best 0.736), test acc: 0.754 (best 0.753)
In epoch 125, loss: 0.006, val acc: 0.736 (best 0.736), test acc: 0.755 (best 0.753)
In epoch 130, loss: 0.006, val acc: 0.736 (best 0.736), test acc: 0.756 (best 0.753)
In epoch 135, loss: 0.005, val acc: 0.736 (best 0.736), test acc: 0.756 (best 0.753)
In epoch 140, loss: 0.005, val acc: 0.736 (best 0.736), test acc: 0.756 (best 0.753)
In epoch 145, loss: 0.005, val acc: 0.738 (best 0.738), test acc: 0.756 (best 0.756)
In epoch 150, loss: 0.004, val acc: 0.738 (best 0.738), test acc: 0.757 (best 0.756)
In epoch 155, loss: 0.004, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 160, loss: 0.004, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 165, loss: 0.004, val acc: 0.738 (best 0.738), test acc: 0.757 (best 0.756)
In epoch 170, loss: 0.004, val acc: 0.738 (best 0.738), test acc: 0.757 (best 0.756)
In epoch 175, loss: 0.003, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 180, loss: 0.003, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 185, loss: 0.003, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 190, loss: 0.003, val acc: 0.738 (best 0.738), test acc: 0.758 (best 0.756)
In epoch 195, loss: 0.003, val acc: 0.740 (best 0.740), test acc: 0.758 (best 0.758)
```

Summary

This tutorial showed DGL's message passing API, g.update_all, together with the built-in message and reduce functions in the dgl.function package, and used them to implement a custom GraphSAGE convolution, both with and without edge weights.