mindchemistry.e3.utils.radius_full

查看源文件
mindchemistry.e3.utils.radius_full(x, y, batch_x=None, batch_y=None)[源代码]

找到 x 中每个元素在 y 中的所有点。

参数:
  • x (Tensor) - x 节点特征矩阵。

  • y (Tensor) - y 节点特征矩阵。

  • batch_x (ndarray) - x 批向量。如果为 None,则根据 x 计算并返回。默认值:None

  • batch_y (ndarray) - y 批向量。如果为 None,则根据 y 计算并返回。默认值:None

返回:
  • edge_index (numpy.ndarray) - 包括边的起点与终点。

  • batch_x (numpy.ndarray) - x 批向量。

  • batch_y (numpy.ndarray) - y 批向量。

异常:
  • ValueError - 如果 xy 的最后一个维度不匹配。

支持平台:

Ascend

样例:

>>> from mindchemistry.e3.utils import radius_full
>>> from mindspore import ops, Tensor
>>> x = Tensor(ops.ones((5, 12, 3)))
>>> edge_index, batch_x, batch_y = radius_full(x, x)
>>> print(edge_index.shape)
(2, 720)
>>> print(batch_x.shape)
(60,)
>>> print(batch_y.shape)
(60,)