mindspore_gl.nn.GMMConv

查看源文件
class mindspore_gl.nn.GMMConv(in_feat_size: int, out_feat_size: int, coord_dim: int, n_kernels: int, residual=False, bias=False, aggregator_type='sum')[源代码]

高斯混合模型卷积层。 来自论文 Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs

\[\begin{split}u_{ij} = f(x_i, x_j), x_j \in \mathcal{N}(i) \\ w_k(u) = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right) \\ h_i^{l+1} = \mathrm{aggregate}\left(\left\{\frac{1}{K} \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right)\end{split}\]

其中 \(u\) 表示顶点与它其中一个邻居之间的伪坐标,使用 函数 \(f\) ,其中 \(\Sigma_k^{-1}\)\(\mu_k\) 是协方差的可学习参数矩阵和高斯核的均值向量。

参数:
  • in_feat_size (int) - 输入节点特征大小。

  • out_feat_size (int) - 输出节点特征大小。

  • coord_dim (int) - 伪坐标的维度。

  • n_kernels (int) - 内核数。

  • residual (bool, 可选) - 是否使用残差。默认值:False

  • bias (bool, 可选) - 是否使用偏置。默认值:False

  • aggregator_type (str, 可选) - 聚合器的类型。默认值:'sum'

输入:
  • x (Tensor) - 输入节点特征。Shape为 \((N, D_{in})\) ,其中 \(N\) 是节点数, \(D_{in}\) 应等于参数中的 in_feat_size

  • pseudo (Tensor) - 伪坐标张量。

  • g (Graph) - 输入图。

输出:
  • Tensor,Shape为 \((N, D_{out})\) 应等于参数中的 out_size

异常:
  • SyntaxError - 当 aggregation_type 不等于 'sum' 时。

  • TypeError - 如果 in_feat_sizeout_feat_sizecoord_dimn_kernels 不是int。

  • TypeError - 如果 biasresual 不是bool。

支持平台:

Ascend GPU

样例:

>>> import mindspore as ms
>>> from mindspore_gl.nn import GMMConv
>>> from mindspore_gl import GraphField
>>> n_nodes = 4
>>> n_edges = 7
>>> node_feat_size = 7
>>> src_idx = ms.Tensor([0, 1, 1, 2, 2, 3, 3], ms.int32)
>>> dst_idx = ms.Tensor([0, 0, 2, 1, 3, 0, 1], ms.int32)
>>> ones = ms.ops.Ones()
>>> node_feat = ones((n_nodes, node_feat_size), ms.float32)
>>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> meanconv = GMMConv(in_feat_size=node_feat_size, out_feat_size=2, coord_dim=3, n_kernels=2)
>>> pseudo = ones((7, 3), ms.float32)
>>> res = meanconv(node_feat, pseudo, *graph_field.get_graph())
>>> print(res.shape)
(4, 2)