mindspore_gl.nn.EGConv

查看源文件
class mindspore_gl.nn.EGConv(in_feat_size: int, out_feat_size: int, aggregators: List[str], num_heads: int = 8, num_bases: int = 4, bias: bool = True)[源代码]

高效图卷积。来自论文 Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions

\[h_i^{(l+1)} = {\LARGE ||}_{h=1}^{H} \sum_{\oplus \in \mathcal{A}} \sum_{b=1}^{B} w_{h,\oplus,b}^{(l)} \bigoplus_{j \in \mathcal{N(i)}} W_{b}^{(l)} h_{j}^{(l)}\]

\(\mathcal{N}(i)\) 表示 \(i\) 的邻居节点, \(W_{b}^{(l)}\) 表示基础权重, \(\oplus\) 表示聚合器, \(w_{h,\oplus,b}^{(l)}\) 表示头部、聚合器和底部的每顶点加权系数。

参数:
  • in_feat_size (int) - 输入节点特征大小。

  • out_feat_size (int) - 输出节点特征大小。

  • aggregators (str, 可选) - 要使用的聚合器。支持的聚合器为 'sum''mean''max''min''std''var''symnorm'

  • num_heads (int, 可选) - 头数 \(H\) 。必须具有 \(out\_feat\_size % num\_heads == 0\) 。默认值:8

  • num_bases (int, 可选) - 基础权重数 \(B\) 。默认值:4

  • bias (bool, 可选) - 是否加入可学习偏置。默认值:True

输入:
  • x (Tensor) - 输入节点功能。Shape为 \((N, D_{in})\) 其中 \(N\) 是节点数,\(D_{in}\) 应等于参数中的 in_feat_size

  • g (Graph) - 输入图表。

输出:
  • Tensor,输出节点特征的Shape为 \((N, D_{out})\) 其中 \((D_{out})\) 应与参数中的 out_feat_size 相等。

异常:
  • TypeError - 如果 in_feat_sizeout_feat_sizenum_heads 不是正整数。

  • ValueError - 如果 out_feat_size 不能被 num_heads 整除。

  • ValueError - 如果 aggregators- 不为 'sum''mean''max''min''symnorm''var''std'

支持平台:

Ascend GPU

样例:

>>> import mindspore as ms
>>> from mindspore_gl.nn import EGConv
>>> from mindspore_gl import GraphField
>>> n_nodes = 4
>>> n_edges = 7
>>> feat_size = 4
>>> src_idx = ms.Tensor([0, 1, 1, 2, 2, 3, 3], ms.int32)
>>> dst_idx = ms.Tensor([0, 0, 2, 1, 3, 0, 1], ms.int32)
>>> ones = ms.ops.Ones()
>>> feat = ones((n_nodes, feat_size), ms.float32)
>>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> conv = EGConv(in_feat_size=4, out_feat_size=6, aggregators=['sum'], num_heads=3, num_bases=3)
>>> res = conv(feat, *graph_field.get_graph())
>>> print(res.shape)
(4, 6)