Python GraphSAGE: Principles and Code Explained, Hands-on GraphSAGE Code, Graph Neural Networks, Graph Computing, GraphSAGE Code Template, Learning GraphSAGE Made Super Simple
2025-01-10 20:00:05 · Python Resources
1. Introduction to GraphSAGE
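GraphSAGE (Graph SAmple and aggreGatE) is a graph neural network model for learning node embeddings. It samples each node's neighbors and aggregates their information into the target node, thereby learning a representation vector for every node.
The core idea of GraphSAGE is to sample a subset of a target node's neighbors and then merge their feature information into the target node with an aggregation operation. Sampling keeps the cost per node bounded, because each node only visits a fixed number of neighbors per layer, while aggregation still preserves information about the graph structure. The GraphSAGE model consists of the following key steps: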
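Sampling: for each node, randomly sample a fixed number of its neighbors to form the sampled neighbor set. The purpose of sampling is to explore each node's local structure so that the model can better capture the node's features.
Encoder: for each node, an encoder transforms its feature vector into a low-dimensional representation vector. The encoder can be a fully connected layer, a convolutional neural network, and so on; in the original GraphSAGE formulation it is a weight matrix followed by a nonlinearity, applied to the concatenation of the node's own vector and its aggregated neighbor vector.
Aggregator: an aggregation operation merges the feature vectors of the sampled neighbor set into the target node. Common choices are mean aggregation and max pooling; the original paper also evaluates an LSTM aggregator. Aggregation can be iterated over multiple layers.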
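Written out, the steps above correspond to the per-layer update of the original GraphSAGE paper (Hamilton et al., 2017), applied for k = 1, …, K. Here N(v) is the sampled neighbor set of v, σ is a nonlinearity and W^k is the layer-k weight matrix (the "encoder"); the last two lines give the mean and max-pooling aggregators, with the max taken element-wise:

```latex
\begin{aligned}
\mathbf{h}_v^{0} &= \mathbf{x}_v \\
\mathbf{h}_{\mathcal{N}(v)}^{k} &= \mathrm{AGGREGATE}_k\left(\{\mathbf{h}_u^{k-1} : u \in \mathcal{N}(v)\}\right) \\
\mathbf{h}_v^{k} &= \sigma\left(\mathbf{W}^{k} \cdot \mathrm{CONCAT}\left(\mathbf{h}_v^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right),
  \qquad \mathbf{h}_v^{k} \leftarrow \mathbf{h}_v^{k} / \lVert \mathbf{h}_v^{k} \rVert_2 \\
\mathbf{z}_v &= \mathbf{h}_v^{K} \\[4pt]
\mathrm{AGGREGATE}^{\mathrm{mean}}_k &= \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} \mathbf{h}_u^{k-1} \\
\mathrm{AGGREGATE}^{\mathrm{pool}}_k &= \max\left(\{\sigma(\mathbf{W}_{\mathrm{pool}} \mathbf{h}_u^{k-1} + \mathbf{b}) : u \in \mathcal{N}(v)\}\right)
\end{aligned}
```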
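By stacking several encoder and aggregator layers, GraphSAGE gradually aggregates information from neighbors at increasing distances: after K layers, a node's representation depends on its K-hop neighborhood. In the end, every node obtains an embedding vector that describes both the node itself and its surrounding structure, and this vector can be used for downstream tasks such as node classification and link prediction.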
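The receptive-field claim can be checked with a small NumPy sketch. It illustrates the idea under simplifying assumptions and is not the reference implementation: it uses each node's full neighbor list instead of sampling, a mean aggregator, tanh as σ, random untrained weights, and a hypothetical 5-node path graph 0-1-2-3-4. With K = 2 layers, node 0's embedding should change when node 2 (two hops away) changes, but stay the same when node 3 (three hops away) changes.

```python
import numpy as np


def sage_mean_layer(h, adj, weight):
    """One GraphSAGE-style layer: mean over each node's full neighbor list,
    concatenated with the node's own vector, linear map, tanh, L2 normalization."""
    neigh_mean = np.stack([h[adj[v]].mean(axis=0) for v in range(h.shape[0])])
    out = np.tanh(np.concatenate([h, neigh_mean], axis=1) @ weight)
    return out / np.linalg.norm(out, axis=1, keepdims=True)


def sage_forward(x, adj, weights):
    """Stack one layer per weight matrix; K layers reach K hops."""
    h = x
    for weight in weights:
        h = sage_mean_layer(h, adj, weight)
    return h


rng = np.random.default_rng(0)
adj = [[1], [0, 2], [1, 3], [2, 4], [3]]        # path graph 0-1-2-3-4
x = rng.normal(size=(5, 4))                      # 4 input features per node
weights = [rng.normal(size=(8, 8)),              # layer 1: CONCAT(4 + 4) -> 8
           rng.normal(size=(16, 3))]             # layer 2: CONCAT(8 + 8) -> 3
z = sage_forward(x, adj, weights)

x_far = x.copy()
x_far[3] += 1.0      # node 3 is three hops from node 0
x_near = x.copy()
x_near[2] += 1.0     # node 2 is two hops from node 0
print(np.allclose(z[0], sage_forward(x_far, adj, weights)[0]))
print(np.allclose(z[0], sage_forward(x_near, adj, weights)[0]))
```

The first comparison should print True and the second False: a third layer would be needed before node 3 could influence node 0.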
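GraphSAGE performs well on node embedding tasks and learns node features from the graph structure effectively. Because it learns aggregation functions rather than a separate embedding for each node, it is inductive: it can produce embeddings for nodes that were not seen during training. It can be applied to social network analysis, recommender systems, image analysis and other fields, and is valuable for mining and analyzing graph-structured data.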
2. Code Walkthrough and Hands-on Practice
Sampling code:
```python
from collections import namedtuple

import tensorflow as tf


# Excerpt: a method of the model class (uses self.batch_size)
def sample(self, inputs, layer_infos, batch_size=None):
    """ Sample neighbors to be the supportive fields for multi-layer convolutions.

    Args:
        inputs: batch inputs
        layer_infos: list of SAGEInfo tuples, one per layer.
        batch_size: the number of inputs (different for batch inputs and negative samples).
    """
    if batch_size is None:
        batch_size = self.batch_size
    samples = [inputs]
    # size of convolution support at each layer per node
    support_size = 1
    support_sizes = [support_size]
    for k in range(len(layer_infos)):
        t = len(layer_infos) - k - 1
        support_size *= layer_infos[t].num_samples
        sampler = layer_infos[t].neigh_sampler
        node = sampler((samples[k], layer_infos[t].num_samples))
        samples.append(tf.reshape(node, [support_size * batch_size,]))
        support_sizes.append(support_size)
    return samples, support_sizes


SAGEInfo = namedtuple("SAGEInfo",
    ['layer_name',     # name of the layer (to get feature embedding etc.)
     'neigh_sampler',  # callable neigh_sampler constructor
     'num_samples',    # number of neighbors sampled at this layer
     'output_dim'      # the output (i.e., hidden) dimension
    ])
```
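The sampling step corresponds to the `sample` function in the source file `models.py`. Its parameter `layer_infos` is a list of `SAGEInfo` tuples. The `neigh_sampler` field of `SAGEInfo` is the sampling algorithm; the source code uses uniform sampling. Because each layer gets its own `SAGEInfo`, different layers can use different samplers. `num_samples` is the number of neighbors sampled at that layer.
The function uses two indices. `k` runs from 0 to K−1 (the loop is `range(len(layer_infos))`) and fixes the order in which each hop's sampled nodes are appended to the `samples` list; `t = K − k − 1` is `k` in reverse order and selects the sampler and the sample count from `layer_infos`. `support_size` is the number of nodes sampled per input node at the current hop (the array for that hop holds `support_size * batch_size` entries). Because every hop is expanded from the nodes of the previous hop, the per-layer sample counts multiply. The function returns `samples`, the list of sampled node arrays, and `support_sizes`, the per-hop node counts.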
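The sampling loop can be mirrored in plain NumPy to see what `samples` and `support_sizes` contain. This is a sketch under assumptions: `uniform_neighbor_sampler` is a hypothetical stand-in for `neigh_sampler` that draws neighbors uniformly with replacement from an adjacency list (the repository's sampler may work differently), `num_samples_per_layer` plays the role of the `num_samples` fields of `layer_infos`, and the 4-node graph is made up.

```python
import numpy as np


def uniform_neighbor_sampler(adj_list, nodes, num_samples, rng):
    """Hypothetical stand-in for neigh_sampler: for every node, draw num_samples
    neighbor ids uniformly at random, with replacement."""
    return np.stack([rng.choice(adj_list[v], size=num_samples, replace=True)
                     for v in nodes])


def sample(inputs, num_samples_per_layer, adj_list, rng):
    """NumPy mirror of the sample() loop.

    Returns the node-id arrays for hops 0..K and the per-input support sizes.
    """
    batch_size = len(inputs)
    samples = [np.asarray(inputs)]
    support_size = 1
    support_sizes = [support_size]
    num_layers = len(num_samples_per_layer)
    for k in range(num_layers):
        t = num_layers - k - 1                  # layer settings are read in reverse
        n = num_samples_per_layer[t]
        support_size *= n                       # fan-out multiplies hop by hop
        node = uniform_neighbor_sampler(adj_list, samples[k], n, rng)  # (len(samples[k]), n)
        samples.append(node.reshape(support_size * batch_size))
        support_sizes.append(support_size)
    return samples, support_sizes


adj_list = [[1, 2], [0, 2, 3], [0, 1], [1]]    # hypothetical 4-node graph
rng = np.random.default_rng(42)
samples, support_sizes = sample([0, 3], [3, 2], adj_list, rng)
print(support_sizes)                 # [1, 2, 6]
print([s.shape for s in samples])    # [(2,), (4,), (12,)]
```

Likewise, `num_samples_per_layer=[25, 10]` gives `support_sizes == [1, 10, 250]`: every input node fans out to 10 first-hop and 250 second-hop samples. Because the reshape is row-major, the samples drawn for parent i occupy one contiguous slice, which is what later allows `aggregate` to reshape them back to `[num_parents, num_samples, dim]`.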
Aggregation code:
```python
import tensorflow as tf


# Excerpt: a method of the model class (uses self.batch_size,
# self.aggregator_cls and self.placeholders)
def aggregate(self, samples, input_features, dims, num_samples, support_sizes, batch_size=None,
        aggregators=None, name=None, concat=False, model_size="small"):
    """ At each layer, aggregate hidden representations of neighbors to compute the hidden representations
        at next layer.
    Args:
        samples: a list of samples of variable hops away for convolving at each layer of the
            network. Length is the number of layers + 1. Each is a vector of node indices.
        input_features: the input features for each sample of various hops away.
        dims: a list of dimensions of the hidden representations from the input layer to the
            final layer. Length is the number of layers + 1.
        num_samples: list of number of samples for each layer.
        support_sizes: the number of nodes to gather information from for each layer.
        batch_size: the number of inputs (different for batch inputs and negative samples).
    Returns:
        The hidden representation at the final layer for all nodes in batch
    """
    if batch_size is None:
        batch_size = self.batch_size

    # length: number of layers + 1
    hidden = [tf.nn.embedding_lookup(input_features, node_samples) for node_samples in samples]
    new_agg = aggregators is None
    if new_agg:
        aggregators = []
    for layer in range(len(num_samples)):
        if new_agg:
            dim_mult = 2 if concat and (layer != 0) else 1
            # aggregator at current layer
            if layer == len(num_samples) - 1:
                aggregator = self.aggregator_cls(dim_mult*dims[layer], dims[layer+1], act=lambda x : x,
                        dropout=self.placeholders['dropout'],
                        name=name, concat=concat, model_size=model_size)
            else:
                aggregator = self.aggregator_cls(dim_mult*dims[layer], dims[layer+1],
                        dropout=self.placeholders['dropout'],
                        name=name, concat=concat, model_size=model_size)
            aggregators.append(aggregator)
        else:
            aggregator = aggregators[layer]
        # hidden representation at current layer for all support nodes that are various hops away
        next_hidden = []
        # as layer increases, the number of support nodes needed decreases
        for hop in range(len(num_samples) - layer):
            dim_mult = 2 if concat and (layer != 0) else 1
            neigh_dims = [batch_size * support_sizes[hop],
                          num_samples[len(num_samples) - hop - 1],
                          dim_mult*dims[layer]]
            h = aggregator((hidden[hop],
                            tf.reshape(hidden[hop + 1], neigh_dims)))
            next_hidden.append(h)
        hidden = next_hidden
    return hidden[0], aggregators
```
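The shape bookkeeping is the subtle part of `aggregate`: at each layer, hop `hop` is updated from its own vectors plus the mean of its `num_samples[len(num_samples) - hop - 1]` sampled children, which is why `hidden[hop + 1]` is reshaped to `[batch_size * support_sizes[hop], num_samples, dim]`; each layer produces one hop fewer until only the batch nodes remain. This NumPy sketch mirrors that loop under stated assumptions: a simplified mean aggregator that applies one weight matrix to CONCAT(self, neighbor mean) (the repository's classes behind `aggregator_cls` may differ in details such as bias, dropout and the `concat` option), ReLU on hidden layers with a linear last layer, and random placeholder node ids and weights.

```python
import numpy as np


def aggregate(samples, features, num_samples, support_sizes, weights):
    """NumPy mirror of the aggregate() loop with a simplified mean aggregator.

    samples[hop] holds the node ids of that hop (length batch_size * support_sizes[hop]);
    weights[layer] maps CONCAT(self, neighbor mean) to the next layer's width.
    """
    batch_size = len(samples[0])
    hidden = [features[ids] for ids in samples]          # like tf.nn.embedding_lookup
    num_layers = len(num_samples)
    for layer in range(num_layers):
        next_hidden = []
        for hop in range(num_layers - layer):            # one hop fewer per layer
            dim = hidden[hop].shape[1]
            neigh = hidden[hop + 1].reshape(batch_size * support_sizes[hop],
                                            num_samples[num_layers - hop - 1], dim)
            h = np.concatenate([hidden[hop], neigh.mean(axis=1)], axis=1) @ weights[layer]
            if layer < num_layers - 1:
                h = np.maximum(h, 0.0)                   # ReLU; last layer stays linear
            next_hidden.append(h)
        hidden = next_hidden
    return hidden[0]


rng = np.random.default_rng(0)
features = rng.normal(size=(10, 4))                      # 10 nodes, 4 features each
num_samples = [3, 2]                                     # same order as layer_infos
support_sizes = [1, 2, 6]                                # what sample() returns for [3, 2]
# placeholder node ids for a batch of 2 (a real run would use sample()'s output)
samples = [rng.integers(0, 10, size=2 * s) for s in support_sizes]
weights = [rng.normal(size=(8, 5)), rng.normal(size=(10, 3))]
out = aggregate(samples, features, num_samples, support_sizes, weights)
print(out.shape)  # (2, 3)
```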
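Code template: the following code uses the `SAGEConv` module from PyTorch Geometric (PyG). `SAGEConv` implements the GraphSAGE aggregation step but does not sample neighbors itself; in this full-batch example every node aggregates over all of its neighbors in `edge_index`. For mini-batch training with neighbor sampling, PyG provides `torch_geometric.loader.NeighborLoader`.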
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid


# Define the GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        # No ReLU on the output layer: log_softmax expects unconstrained logits
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# Load the dataset
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

# Create the model and optimizer
model = GraphSAGE(in_channels=dataset.num_features, hidden_channels=16, out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def index_to_mask(idx, size):
    # Boolean masks, so that mask.sum() counts the selected nodes
    mask = torch.zeros(size, dtype=torch.bool)
    mask[idx] = True
    return mask


# Random 80/10/10 split into training, validation and test sets
num_nodes = data.num_nodes
indices = torch.randperm(num_nodes)
train_mask = index_to_mask(indices[:int(0.8 * num_nodes)], num_nodes)
val_mask = index_to_mask(indices[int(0.8 * num_nodes):int(0.9 * num_nodes)], num_nodes)
test_mask = index_to_mask(indices[int(0.9 * num_nodes):], num_nodes)


def train():
    model.train()
    optimizer.zero_grad()
    output = model(data.x, data.edge_index)[train_mask]
    loss = F.nll_loss(output, data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for mask in [train_mask, val_mask, test_mask]:
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


# Train the model
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
```
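According to the PyG documentation, `SAGEConv` with its default mean aggregation computes x_i' = W_1 x_i + W_2 · mean_{j ∈ N(i)} x_j. This is the concatenation form of the original GraphSAGE update, W · CONCAT(h_v, h_N(v)), in disguise: multiplying CONCAT(x_i, mean_j x_j) by the stacked matrix [W_1 | W_2] gives the same result. The NumPy check below needs no PyG; the 4-node graph (written in PyG's `edge_index` layout, row 0 = source nodes, row 1 = target nodes) and the weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, d_in, d_out = 4, 3, 2
x = rng.normal(size=(num_nodes, d_in))
# undirected edges 0-1, 0-2, 1-2, 2-3 stored in both directions (source -> target)
edge_index = np.array([[1, 2, 0, 2, 0, 1, 3, 2],
                       [0, 0, 1, 1, 2, 2, 2, 3]])
W1 = rng.normal(size=(d_out, d_in))   # applied to the node itself
W2 = rng.normal(size=(d_out, d_in))   # applied to the neighbor mean


def neighbor_mean(x, edge_index):
    """Mean of incoming source features at each target node."""
    src, dst = edge_index
    sums = np.zeros_like(x)
    np.add.at(sums, dst, x[src])                     # scatter-add messages to targets
    counts = np.bincount(dst, minlength=x.shape[0])[:, None]
    return sums / np.maximum(counts, 1)              # isolated nodes get a zero vector


m = neighbor_mean(x, edge_index)
two_linear = x @ W1.T + m @ W2.T                     # SAGEConv-style formula
concat_form = np.concatenate([x, m], axis=1) @ np.concatenate([W1, W2], axis=1).T
print(np.allclose(two_linear, concat_form))  # True
```

Keeping two separate matrices (rather than one matrix on a concatenation) is purely an implementation convenience; both parameterizations describe the same family of layers.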