mindchemistry.e3.utils.radius

查看源文件
mindchemistry.e3.utils.radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32)[源代码]

x 中找到每个 y 元素在距离 r 内的所有点。

参数:
  • x (ndarray) - x 节点特征矩阵。

  • y (ndarray) - y 节点特征矩阵。

  • r (ndarray, float) - 半径。

  • batch_x (ndarray) - x 批向量。如果为 None,则根据 x 计算并返回。默认值:None

  • batch_y (ndarray) - y 批向量。如果为 None,则根据 y 计算并返回。默认值:None

  • max_num_neighbors (int) - 返回每个 y 元素的最大邻居数量。默认值:32

返回:
  • edge_index (numpy.ndarray) - 包括边的起点与终点。

  • batch_x (numpy.ndarray) - x 批向量。

  • batch_y (numpy.ndarray) - y 批向量。

异常:
  • ValueError - 如果 xy 的最后一个维度不匹配。

支持平台:

Ascend

样例:

>>> from mindchemistry.e3.utils import radius
>>> import numpy as np
>>> np.random.seed(1)
>>> x = np.random.random((5, 12, 3))
>>> r = 0.5
>>> edge_index, batch_x, batch_y = radius(x, x, r)
>>> print(edge_index.shape)
(2, 222)
>>> print(batch_x.shape)
(60,)
>>> print(batch_y.shape)
(60,)