mindchemistry.e3.utils.radius_graph

查看源文件
mindchemistry.e3.utils.radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32, flow='source_to_target')[源代码]

计算给定距离内图所有点之间的边。

参数:
  • x (ndarray) - 节点特征矩阵。

  • r (ndarray, float) - 半径。

  • batch (Tensor) - 批向量。如果为 None,则计算并返回。默认值:None

  • loop (bool) - 图中是否包含自环。默认值:False

  • max_num_neighbors (int) - 返回每个 y 元素的最大邻居数量。默认值:32

  • flow (str) - {'source_to_target', 'target_to_source'},与消息传递结合使用时的流向。默认值:'source_to_target'

返回:
  • edge_index (ndarray) - 包括边的起点与终点。

  • batch (ndarray) - 批向量。

异常:
  • ValueError - 如果 flow 不是 {'source_to_target', 'target_to_source'} 之一。

支持平台:

Ascend

样例:

>>> from mindchemistry.e3.utils import radius_graph
>>> import numpy as np
>>> np.random.seed(1)
>>> x = np.random.random((5, 12, 3))
>>> r = 0.5
>>> edge_index, batch = radius_graph(x, r)
>>> print(edge_index.shape)
(2, 162)
>>> print(batch.shape)
(60,)