mindspore_gl.nn.SAGPooling

class mindspore_gl.nn.SAGPooling(in_feat_size: int, GNN=GCNConv2, activation=ms.nn.Tanh, multiplier=1.0)

The self-attention pooling operator, from the papers Self-Attention Graph Pooling and Understanding Attention and Generalization in Graph Neural Networks.

\[\begin{aligned}
\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A}) \\
\mathbf{i} &= \mathrm{top}_k(\mathbf{y}) \\
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}} \\
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}
\end{aligned}\]
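
As a rough illustration of these equations, the sketch below substitutes a hypothetical linear projection for the GNN scorer (the weight w is illustrative and not part of this API; the layer itself computes \(\mathbf{y}\) with its GNN argument, by default GCNConv2):

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> X = rng.random((6, 4)).astype(np.float32)          # node features: N=6 nodes, D=4
>>> A = (rng.random((6, 6)) > 0.5).astype(np.float32)  # dense adjacency matrix, N x N
>>> w = rng.random((4, 1)).astype(np.float32)          # hypothetical scoring weights
>>> y = (X @ w).squeeze(-1)                            # stands in for y = GNN(X, A)
>>> k = 3
>>> i = np.argsort(y)[-k:]                             # i = top_k(y)
>>> x_prime = (X * np.tanh(y)[:, None])[i]             # X' = (X ⊙ tanh(y))_i
>>> a_prime = A[np.ix_(i, i)]                          # A' = A_{i,i}
>>> print(x_prime.shape)
(3, 4)
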
Parameters
  • in_feat_size (int) – Size of each input sample.

  • GNN (GNNCell, optional) – A graph neural network layer for calculating projection scores. Only GCNConv2 is supported. Default: mindspore_gl.nn.conv.GCNConv2.

  • activation (Cell, optional) – The nonlinearity activation function Cell to use. Default: mindspore.nn.Tanh.

  • multiplier (float, optional) – A scalar for scaling node features. Default: 1.0.
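
For instance, a layer that doubles the retained node features while keeping the default GCNConv2 scorer could be constructed as follows (a usage sketch; the argument values are illustrative only):

>>> import mindspore as ms
>>> from mindspore_gl.nn import SAGPooling
>>> # in_feat_size must match the feature dimension D of the input nodes.
>>> pool = SAGPooling(in_feat_size=16, activation=ms.nn.Tanh, multiplier=2.0)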

Inputs:
  • x (Tensor) - The input node features to be updated. The shape is \((N, D)\), where \(N\) is the number of nodes and \(D\) is the node feature size. When attn is None, \(D\) should equal in_feat_size in Parameters.

  • attn (Tensor) - The input node features for calculating the projection score. The shape is \((N, D_{in})\), where \(N\) is the number of nodes and \(D_{in}\) should equal in_feat_size in Parameters. attn can be None, in which case x is used to calculate the projection score.

  • node_num (Int) - The total number of nodes in g.

  • perm_num (Int) - The expected \(k\) for top-k node filtering.

  • g (BatchedGraph) - The input graph.

Outputs:
  • x (Tensor) - The updated node features. The shape is \((2, M, D_{out})\), where \(M\) equals perm_num in Inputs and \(D_{out}\) equals \(D\) in Inputs.

  • src_perm (Tensor) - The updated source nodes.

  • dst_perm (Tensor) - The updated destination nodes.

  • ver_subgraph_idx (Tensor) - The updated vertex-to-subgraph indices of the retained nodes.

  • edge_subgraph_idx (Tensor) - The updated edge-to-subgraph indices of the retained edges.

  • perm (Tensor) - The indices of the top-k nodes before the node indices are updated. The shape is \((M,)\), where \(M\) equals perm_num in Inputs.

  • perm_score (Tensor) - The projection scores of the updated nodes.
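
Since perm records the original indices of the retained nodes, it can be used to gather their pre-pooling features. A minimal sketch, assuming perm is a flat index tensor of shape \((M,)\) and node_feat is the input tensor from the Examples section below:

>>> import mindspore as ms
>>> # Gather the original (unscaled) features of the top-k retained nodes, shape (M, D).
>>> original_feat = ms.ops.gather(node_feat, perm, 0)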

Raises

TypeError – If in_feat_size is not an int.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore_gl.nn import SAGPooling
>>> from mindspore_gl import BatchedGraphField
>>> node_feat = ms.Tensor([[1, 2, 3, 4], [2, 4, 1, 3], [1, 3, 2, 4],
...                        [9, 7, 5, 8], [8, 7, 6, 5], [8, 6, 4, 6], [1, 2, 1, 1]],
...                       ms.float32)
>>> n_nodes = 7
>>> n_edges = 8
>>> src_idx = ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6], ms.int32)
>>> dst_idx = ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4], ms.int32)
>>> ver_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1], ms.int32)
>>> edge_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1, 1], ms.int32)
>>> graph_mask = ms.Tensor([0, 1], ms.int32)
>>> batched_graph_field = BatchedGraphField(src_idx, dst_idx, n_nodes, n_edges, ver_subgraph_idx,
...                                         edge_subgraph_idx, graph_mask)
>>> net = SAGPooling(4)
>>> feature, src, dst, ver_subgraph, edge_subgraph, perm, perm_score = net(node_feat, None, n_nodes, 2,
...                                                                        *batched_graph_field.get_batched_graph())
>>> print(feature.shape)
(2, 2, 4)