mindchemistry.cell.orb.AttentionInteractionNetwork
- class mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: str = 'sigmoid', distance_cutoff: bool = True, polynomial_order: int = 4, cutoff_rmax: float = 6.0)[源代码]
注意力交互网络。实现基于注意力机制的消息传递神经网络层,用于分子图的边更新。
- 参数:
num_node_in (int) - 节点输入特征数量。
num_node_out (int) - 节点输出特征数量。
num_edge_in (int) - 边输入特征数量。
num_edge_out (int) - 边输出特征数量。
num_mlp_layers (int) - 节点和边更新MLP的隐藏层数量。
mlp_hidden_dim (int) - MLP的隐藏维度大小。
attention_gate (str,可选) - 注意力门类型,
"sigmoid"
或"softmax"
。默认值:"sigmoid"
。distance_cutoff (bool,可选) - 是否使用基于距离的边截断。默认值:
True
。polynomial_order (int,可选) - 多项式截断函数的阶数。默认值:
4
。cutoff_rmax (float,可选) - 截断的最大距离。默认值:
6.0
。
- 输入:
graph_edges (dict) - 边特征字典,必须包含键"feat",形状为 \((n_{edges}, num\_edge\_in)\)。
graph_nodes (dict) - 节点特征字典,必须包含键"feat",形状为 \((n_{nodes}, num\_node\_in)\)。
senders (Tensor) - 每条边的发送节点索引,形状为 \((n_{edges},)\)。
receivers (Tensor) - 每条边的接收节点索引,形状为 \((n_{edges},)\)。
- 输出:
edges (dict) - 更新的边特征字典,键"feat"的形状为 \((n_{edges}, num\_edge\_out)\)。
nodes (dict) - 更新的节点特征字典,键"feat"的形状为 \((n_{nodes}, num\_node\_out)\)。
- 异常:
ValueError - 如果 attention_gate 不是"sigmoid"或"softmax"。
ValueError - 如果边或节点特征不包含必需的"feat"键。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> import mindspore >>> from mindspore import Tensor >>> from mindchemistry.cell.orb.gns import AttentionInteractionNetwork >>> attn_net = AttentionInteractionNetwork( ... num_node_in=256, ... num_node_out=256, ... num_edge_in=256, ... num_edge_out=256, ... num_mlp_layers=2, ... mlp_hidden_dim=512, ... ) >>> n_atoms = 4 >>> n_edges = 10 >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) >>> for i, num in enumerate(atomic_numbers.asnumpy()): ... atomic_numbers_embedding_np[i, num - 1] = 1.0 >>> node_features = { ... "atomic_numbers": atomic_numbers, ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) ... } >>> edge_features = { ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) ... } >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) >>> edges, nodes = attn_net( ... edge_features, ... node_features, ... senders, ... receivers, ... ) >>> print(edges["feat"].shape, nodes["feat"].shape) (10, 256) (4, 256)