mindspore_gl.nn.SAGPooling

查看源文件
class mindspore_gl.nn.SAGPooling(in_channels: int, GNN=GCNConv2, activation=ms.nn.Tanh, multiplier=1.0)[源代码]

基于self-attention的池化操作。来自 Self-Attention Graph PoolingUnderstanding Attention and Generalization in Graph Neural Networks

\[ \begin{align}\begin{aligned}\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]
参数:
  • in_channels (int) - 每个输入样本的大小。

  • GNN (GNNCell, 可选) - 用于计算投影分数的图神经网络层,仅支持GCNConv2。默认值:mindspore_gl.nn.con.GCNConv2

  • activation (Cell, 可选) - 非线性激活函数。默认值:mindspore.nn.Tanh

  • multiplier (float, 可选) - 用于缩放节点功能的标量。默认值:1.0

输入:
  • x (Tensor) - 要更新的输入节点特征。Shape为 \((N, D)\) 其中 \(N\) 是节点数, \(D\) 是节点的特征大小,当 attn 为None时,D 应等于参数中的 in_feat_size

  • attn (Tensor) - 用于计算投影分数的输入节点特征。Shape为 \((N, D_{in})\) 其中 \(N\) 是节点数, \(D_{in}\) 应等于参数中的 in_feat_size 。 如果用 x 计算投影分数, attn 可以为None。

  • node_num (Int) - 以图g中的节点总数。

  • perm_num (Int) - Topk个节点过滤中k值。

  • g (BatchedGraph) - 输入图。

输出:
  • x (Tensor) - 更新的节点特征。Shape为 \((2, M, D_{out})\) 其中 \(M\) 等于 Inputs 中的 perm_num\(D_{out}\) 等于 Inputs 中的 D

  • src_perm (Tensor) - 更新的src节点。

  • dst_perm (Tensor) - 更新的dst节点。

  • perm (Tensor) - 更新节点索引之前topk节点的节点索引。Shape为 \(M\),其中 \(M\) 等于 Inputs 中的 perm_num

  • perm_score (Tensor) - 更新节点的投影分数。

异常:
  • TypeError - 如果 in_feat_sizeout_size 不是int。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore_gl.nn import SAGPooling
>>> from mindspore_gl import BatchedGraphField
>>> node_feat = ms.Tensor([[1, 2, 3, 4], [2, 4, 1, 3], [1, 3, 2, 4],
...                        [9, 7, 5, 8], [8, 7, 6, 5], [8, 6, 4, 6], [1, 2, 1, 1]],
...                       ms.float32)
>>> n_nodes = 7
>>> n_edges = 8
>>> src_idx = ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6], ms.int32)
>>> dst_idx = ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4], ms.int32)
>>> ver_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1], ms.int32)
>>> edge_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1, 1], ms.int32)
>>> graph_mask = ms.Tensor([0, 1], ms.int32)
>>> batched_graph_field = BatchedGraphField(src_idx, dst_idx, n_nodes, n_edges, ver_subgraph_idx,
...                                         edge_subgraph_idx, graph_mask)
>>> net = SAGPooling(4)
>>> feature, src, dst, ver_subgraph, edge_subgraph, perm, perm_score = net(node_feat, None, 2,
...                                                                    *batched_graph_field.get_batched_graph())
>>> print(feature.shape)
(2, 2, 4)