mindchemistry.cell.orb.AttentionInteractionNetwork

class mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: str = 'sigmoid', distance_cutoff: bool = True, polynomial_order: int = 4, cutoff_rmax: float = 6.0)

Attention interaction network. Implements an attention-based message-passing neural network layer that updates the edge and node features of a molecular graph.

Parameters:
  • num_node_in (int) - Number of input node features.

  • num_node_out (int) - Number of output node features.

  • num_edge_in (int) - Number of input edge features.

  • num_edge_out (int) - Number of output edge features.

  • num_mlp_layers (int) - Number of hidden layers in the node- and edge-update MLPs.

  • mlp_hidden_dim (int) - Hidden dimension of the MLPs.

  • attention_gate (str, optional) - Type of attention gate, either "sigmoid" or "softmax". Default: "sigmoid".

  • distance_cutoff (bool, optional) - Whether to apply a distance-based edge cutoff. Default: True.

  • polynomial_order (int, optional) - Order of the polynomial cutoff function. Default: 4.

  • cutoff_rmax (float, optional) - Maximum cutoff distance. Default: 6.0.
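The roles of polynomial_order and cutoff_rmax can be illustrated with a common smooth polynomial envelope of the kind used for distance cutoffs in molecular message passing. This is a sketch under that assumption; the exact polynomial used by AttentionInteractionNetwork may differ.

```python
import numpy as np

def polynomial_cutoff(r, rmax=6.0, p=4):
    """Smooth polynomial envelope: equals 1 at r = 0, decays to 0 at r >= rmax.

    Illustrative only; `p` plays the role of polynomial_order and
    `rmax` the role of cutoff_rmax.
    """
    d = np.clip(r / rmax, 0.0, 1.0)
    env = (1.0
           - (p + 1) * (p + 2) / 2.0 * d ** p
           + p * (p + 2) * d ** (p + 1)
           - p * (p + 1) / 2.0 * d ** (p + 2))
    # Edges beyond the cutoff radius contribute nothing.
    return np.where(r < rmax, env, 0.0)

r = np.array([0.0, 3.0, 6.0, 8.0])
print(polynomial_cutoff(r))  # weights fall smoothly from 1 to 0
```

Multiplying each edge message by such a weight makes distant edges fade out smoothly instead of being cut off abruptly, which keeps the learned potential differentiable in the atomic positions.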

Inputs:
  • graph_edges (dict) - Dictionary of edge features; must contain the key "feat" with shape \((n_{edges}, num\_edge\_in)\).

  • graph_nodes (dict) - Dictionary of node features; must contain the key "feat" with shape \((n_{nodes}, num\_node\_in)\).

  • senders (Tensor) - Sender node index of each edge, with shape \((n_{edges},)\).

  • receivers (Tensor) - Receiver node index of each edge, with shape \((n_{edges},)\).

Outputs:
  • edges (dict) - Updated edge feature dictionary; the key "feat" has shape \((n_{edges}, num\_edge\_out)\).

  • nodes (dict) - Updated node feature dictionary; the key "feat" has shape \((n_{nodes}, num\_node\_out)\).
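How senders, receivers, and the attention gate fit together can be sketched in plain NumPy. This is an illustrative assumption, not the actual implementation: each edge reads its sender's features, a sigmoid gate weights the resulting message, and gated messages are scatter-summed into their receiver nodes.

```python
import numpy as np

def gated_message_passing(node_feat, edge_feat, senders, receivers, scores):
    """Toy attention-gated aggregation (illustrative only).

    `scores` are per-edge attention logits; a sigmoid gate weights each
    message before it is summed into its receiver node.
    """
    gate = 1.0 / (1.0 + np.exp(-scores))             # sigmoid attention gate
    messages = gate[:, None] * (node_feat[senders] + edge_feat)
    out = np.zeros_like(node_feat)
    np.add.at(out, receivers, messages)              # scatter-add per receiver
    return out

nodes = np.random.randn(4, 8)
edges = np.random.randn(10, 8)
snd = np.random.randint(0, 4, size=10)
rcv = np.random.randint(0, 4, size=10)
print(gated_message_passing(nodes, edges, snd, rcv, np.random.randn(10)).shape)  # (4, 8)
```

With attention_gate="softmax", the natural variant normalizes the gates over all edges sharing a receiver instead of squashing each edge independently.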

Raises:
  • ValueError - If attention_gate is neither "sigmoid" nor "softmax".

  • ValueError - If the edge or node features do not contain the required "feat" key.

Supported Platforms:

Ascend

Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb.gns import AttentionInteractionNetwork
>>> attn_net = AttentionInteractionNetwork(
...     num_node_in=256,
...     num_node_out=256,
...     num_edge_in=256,
...     num_edge_out=256,
...     num_mlp_layers=2,
...     mlp_hidden_dim=512,
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)),
...     "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> edges, nodes = attn_net(
...     edge_features,
...     node_features,
...     senders,
...     receivers,
... )
>>> print(edges["feat"].shape, nodes["feat"].shape)
(10, 256) (4, 256)