首页 > Python资料 博客日记

Python GraphSAGE原理与代码详解,GraphSAGE代码实战,图神经网络,图计算,GraphSAGE代码模版,超简单学习GraphSAGE

2025-01-10 20:00:05Python资料围观1

这篇文章介绍了Python GraphSAGE原理与代码详解,GraphSAGE代码实战,图神经网络,图计算,GraphSAGE代码模版,超简单学习GraphSAGE,分享给大家做个参考,收藏Python资料网收获更多编程知识

1.GraphSAGE简介

GraphSAGE(Graph Sample and Aggregated)是一种用于图节点嵌入学习的图神经网络模型。它通过采样和聚合的方式,将邻居节点的信息聚合到目标节点上,从而学习节点的表示向量。

GraphSAGE的核心思想是从目标节点的邻居节点中采样一部分节点,然后通过聚合操作将邻居节点的特征信息整合到目标节点上。这样一方面减少了计算复杂度,另一方面也保留了图结构中的信息。GraphSAGE模型包含以下几个重要的步骤:

  1. 采样:针对每个节点,从其邻居节点中随机采样一定数量的节点作为采样节点集合。采样的目的是探索节点的局部结构,以便更好地捕捉节点的特征。

  2. 编码器:对于每个节点,通过一个编码器将其自身的特征向量转换为一个低维的表示向量。编码器可以是一个全连接层、一个卷积神经网络等。

  3. 聚合器:通过聚合操作将采样节点集合中的特征向量聚合到目标节点上。常用的聚合方法有均值聚合、最大池化等。聚合的过程可以通过多层的聚合器进行迭代。

通过多层的编码器和聚合器,GraphSAGE能够逐渐聚合更多层次的邻居节点信息,并且逐渐扩大目标节点对邻居节点的感知范围。最终,每个节点都能够获得一个表示其自身和周围结构的嵌入向量,该向量可以用于下游的节点分类、链接预测等任务。

GraphSAGE在图节点嵌入学习任务中具有较好的性能,能够有效地学习图结构中的节点特征。它可以用于社交网络分析、推荐系统、图像分析等领域,对于挖掘和分析图结构数据具有重要的应用价值。

 

2.代码讲解与实战

采样代码:

def sample(self, inputs, layer_infos, batch_size=None):
      """ Sample neighbors to be the supportive fields for multi-layer convolutions.
      Args:
          inputs: batch inputs
          batch_size: the number of inputs (different for batch inputs and negative samples).
      """

      if batch_size is None:
          batch_size = self.batch_size
      samples = [inputs]
      # size of convolution support at each layer per node
      support_size = 1
      support_sizes = [support_size]
      for k in range(len(layer_infos)):
          t = len(layer_infos) - k - 1
          support_size *= layer_infos[t].num_samples
          sampler = layer_infos[t].neigh_sampler
          node = sampler((samples[k], layer_infos[t].num_samples))
          samples.append(tf.reshape(node, [support_size * batch_size,]))
          support_sizes.append(support_size)
      return samples, support_sizes
SAGEInfo = namedtuple("SAGEInfo",
    ['layer_name', # name of the layer (to get feature embedding etc.)
     'neigh_sampler', # callable neigh_sampler constructor
     'num_samples',
     'output_dim' # the output (i.e., hidden) dimension
    ])

采样过程其对应的源码在model.pysample函数,函数的入参layer_infos是由SAGEInfo元祖组成的list,SAGEInfo中的neigh_sampler表示抽样算法,源码中使用的是均匀采样,因为每一层都会选择一组SAGEInfo,因此每一层是可以使用不同的采样器的。num_samples是当前层的采样的邻居数。

在源码中有两个index,其中k的顺序是从1到 K,用来拼接各层采样的节点组成的list,tk的逆序,用于确定采样函数和样本数等超参。变量support_size是当前层要采样的样本数,因为第 K−1 层是在 K层的基础上发散得到的,因此需要进行乘法的叠加。最终函数返回的是采样点的samples数组和各层的节点数目support_sizes数组。

聚合代码:

def aggregate(self, samples, input_features, dims, num_samples, support_sizes, batch_size=None,
            aggregators=None, name=None, concat=False, model_size="small"):
        """ At each layer, aggregate hidden representations of neighbors to compute the hidden representations 
            at next layer.
        Args:
            samples: a list of samples of variable hops away for convolving at each layer of the
                network. Length is the number of layers + 1. Each is a vector of node indices.
            input_features: the input features for each sample of various hops away.
            dims: a list of dimensions of the hidden representations from the input layer to the
                final layer. Length is the number of layers + 1.
            num_samples: list of number of samples for each layer.
            support_sizes: the number of nodes to gather information from for each layer.
            batch_size: the number of inputs (different for batch inputs and negative samples).
        Returns:
            The hidden representation at the final layer for all nodes in batch
        """

        if batch_size is None:
            batch_size = self.batch_size

        # length: number of layers + 1
        hidden = [tf.nn.embedding_lookup(input_features, node_samples) for node_samples in samples]
        new_agg = aggregators is None
        if new_agg:
            aggregators = []
        for layer in range(len(num_samples)):
            if new_agg:
                dim_mult = 2 if concat and (layer != 0) else 1
                # aggregator at current layer
                if layer == len(num_samples) - 1:
                    aggregator = self.aggregator_cls(dim_mult*dims[layer], dims[layer+1], act=lambda x : x,
                            dropout=self.placeholders['dropout'], 
                            name=name, concat=concat, model_size=model_size)
                else:
                    aggregator = self.aggregator_cls(dim_mult*dims[layer], dims[layer+1],
                            dropout=self.placeholders['dropout'], 
                            name=name, concat=concat, model_size=model_size)
                aggregators.append(aggregator)
            else:
                aggregator = aggregators[layer]
            # hidden representation at current layer for all support nodes that are various hops away
            next_hidden = []
            # as layer increases, the number of support nodes needed decreases
            for hop in range(len(num_samples) - layer):
                dim_mult = 2 if concat and (layer != 0) else 1
                neigh_dims = [batch_size * support_sizes[hop], 
                              num_samples[len(num_samples) - hop - 1], 
                              dim_mult*dims[layer]]
                h = aggregator((hidden[hop],
                                tf.reshape(hidden[hop + 1], neigh_dims)))
                next_hidden.append(h)
            hidden = next_hidden
        return hidden[0], aggregators

 代码模版:代码直接调用采样和聚合过程的模块SAGEConv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

# 定义GraphSAGE模型
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return F.log_softmax(x, dim=1)

# 创建数据集
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

# 创建模型和优化器
model = GraphSAGE(in_channels=dataset.num_features, hidden_channels=16, out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 划分训练集、验证集和测试集
num_nodes = data.num_nodes
indices = torch.randperm(num_nodes)
train_mask = indices[:int(0.8 * num_nodes)]
val_mask = indices[int(0.8 * num_nodes):int(0.9 * num_nodes)]
test_mask = indices[int(0.9 * num_nodes):]

def train():
    model.train()
    optimizer.zero_grad()
    output = model(data.x, data.edge_index)[train_mask]
    loss = F.nll_loss(output, data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for mask in [train_mask, val_mask, test_mask]:
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

# 训练模型
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

版权声明:本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:jacktools123@163.com进行投诉反馈,一经查实,立即删除!

标签:

相关文章

本站推荐