mindchemistry.cell.orb.AttentionInteractionNetwork
- class mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: Literal['sigmoid', 'softmax'] = 'sigmoid', distance_cutoff: bool = True, polynomial_order: Optional[int] = 4, cutoff_rmax: Optional[float] = 6.0)
Attention interaction network. Implements an attention-based message-passing layer that updates both the edge features and the node features of a molecular graph.
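As a rough illustration of attention-gated message passing (not MindChemistry's actual implementation), the NumPy sketch below updates each edge from its own features plus those of its sender and receiver nodes, scores each edge with a scalar attention logit, and sums the gated edge messages into the receiving nodes. The helper names `segment_softmax` and `attention_message_passing`, the random linear maps standing in for the MLPs, and the concatenation order are all hypothetical. The switch between a per-edge "sigmoid" gate and a "softmax" gate mirrors the `attention_gate` parameter; normalizing the softmax over each receiver's incoming edges is an assumption.

```python
import numpy as np


def segment_softmax(logits, segment_ids, num_segments):
    """Softmax over per-edge logits, normalized within each segment (receiver node)."""
    # Per-segment maximum, subtracted for numerical stability.
    seg_max = np.full(num_segments, -np.inf)
    np.maximum.at(seg_max, segment_ids, logits)
    exp = np.exp(logits - seg_max[segment_ids])
    # Per-segment sum of the exponentials.
    seg_sum = np.zeros(num_segments)
    np.add.at(seg_sum, segment_ids, exp)
    return exp / seg_sum[segment_ids]


def attention_message_passing(node_feat, edge_feat, senders, receivers,
                              gate="sigmoid", seed=0):
    """Hypothetical attention-gated edge and node update (illustration only)."""
    rng = np.random.default_rng(seed)
    n_nodes, d_node = node_feat.shape
    d_edge = edge_feat.shape[1]
    # Edge input: [edge features, sender node features, receiver node features].
    edge_in = np.concatenate(
        [edge_feat, node_feat[senders], node_feat[receivers]], axis=1)
    k = edge_in.shape[1]
    w_edge = rng.standard_normal((k, d_edge)) / np.sqrt(k)
    w_attn = rng.standard_normal(k) / np.sqrt(k)
    new_edges = np.tanh(edge_in @ w_edge)  # stand-in for the edge MLP
    logits = edge_in @ w_attn              # one scalar attention logit per edge
    if gate == "sigmoid":
        weights = 1.0 / (1.0 + np.exp(-logits))  # independent per-edge gate
    elif gate == "softmax":
        weights = segment_softmax(logits, receivers, n_nodes)
    else:
        raise ValueError("gate must be 'sigmoid' or 'softmax'")
    # Sum the gated edge messages into their receiver nodes.
    agg = np.zeros((n_nodes, d_edge))
    np.add.at(agg, receivers, weights[:, None] * new_edges)
    node_in = np.concatenate([node_feat, agg], axis=1)
    w_node = rng.standard_normal((node_in.shape[1], d_node)) / np.sqrt(node_in.shape[1])
    new_nodes = np.tanh(node_in @ w_node)  # stand-in for the node MLP
    return new_edges, new_nodes, weights


rng = np.random.default_rng(1)
nodes = rng.standard_normal((4, 8))
edges = rng.standard_normal((10, 8))
senders = rng.integers(0, 4, size=10)
receivers = rng.integers(0, 4, size=10)
new_edges, new_nodes, weights = attention_message_passing(
    nodes, edges, senders, receivers, gate="softmax")
print(new_edges.shape, new_nodes.shape)  # (10, 8) (4, 8)
```

With "sigmoid", each edge weight lies in (0, 1) independently of the other edges; with "softmax" as sketched here, the weights of all edges arriving at the same node sum to 1.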
- Parameters
num_node_in (int) – Number of input node features.
num_node_out (int) – Number of output node features.
num_edge_in (int) – Number of input edge features.
num_edge_out (int) – Number of output edge features.
num_mlp_layers (int) – Number of hidden layers in the node and edge update MLPs.
mlp_hidden_dim (int) – Hidden dimension of the MLPs.
attention_gate (str, optional) – Attention gate type, either "sigmoid" or "softmax". Default: "sigmoid".
distance_cutoff (bool, optional) – Whether to apply a distance-based edge cutoff. Default: True.
polynomial_order (int, optional) – Order of the polynomial cutoff function. Default: 4.
cutoff_rmax (float, optional) – Maximum distance for the cutoff. Default: 6.0.
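This page does not spell out the cutoff function that `distance_cutoff`, `polynomial_order` and `cutoff_rmax` control. One common choice for a smooth polynomial cutoff is the DimeNet-style envelope, which equals 1 at zero distance and falls smoothly to 0 at the cutoff radius. The sketch below assumes that form, with `polynomial_order` playing the role of p and `cutoff_rmax` the role of r_max; the helper `polynomial_envelope` is hypothetical, not a MindChemistry API.

```python
import numpy as np


def polynomial_envelope(r, r_max=6.0, p=4):
    """Smooth cutoff: 1 at r = 0, decreasing monotonically to 0 at r = r_max.

    DimeNet-style polynomial envelope (assumed, not confirmed for this class):
        u(d) = 1 - (p+1)(p+2)/2 * d**p + p(p+2) * d**(p+1) - p(p+1)/2 * d**(p+2)
    with d = r / r_max; edges at or beyond r_max get weight 0.
    """
    d = np.asarray(r, dtype=np.float64) / r_max
    u = (1.0
         - (p + 1) * (p + 2) / 2.0 * d**p
         + p * (p + 2) * d**(p + 1)
         - p * (p + 1) / 2.0 * d**(p + 2))
    return np.where(d < 1.0, u, 0.0)


r = np.array([0.0, 3.0, 6.0, 9.0])
print(polynomial_envelope(r, r_max=6.0, p=4))  # values 1.0, 0.65625, 0.0, 0.0
```

Its derivative is -p(p+1)(p+2)/2 · d^(p-1) (1 - d)^2, so the weight decreases monotonically and reaches the cutoff with zero slope, which keeps edge contributions from switching on or off abruptly as atoms move across r_max.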
- Inputs:
graph_edges (dict) - Edge feature dictionary, must contain key "feat" with shape \((n_{edges}, num\_edge\_in)\).
graph_nodes (dict) - Node feature dictionary, must contain key "feat" with shape \((n_{nodes}, num\_node\_in)\).
senders (Tensor) - Sender node indices for each edge, shape \((n_{edges},)\).
receivers (Tensor) - Receiver node indices for each edge, shape \((n_{edges},)\).
- Outputs:
edges (dict) - Updated edge feature dictionary with key "feat" of shape \((n_{edges}, num\_edge\_out)\).
nodes (dict) - Updated node feature dictionary with key "feat" of shape \((n_{nodes}, num\_node\_out)\).
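The example in the Examples section below fills "r" and "vectors" with random values unrelated to "positions", and draws senders and receivers at random. For physically consistent inputs, the edge list can instead be derived from atom positions. A minimal NumPy sketch, using a hypothetical helper `radius_graph` (not a MindChemistry API) that connects every ordered pair of distinct atoms closer than `r_max`. Taking "vectors" as receiver position minus sender position is an assumption, since this page does not state the convention.

```python
import numpy as np


def radius_graph(positions, r_max=6.0):
    """Return (senders, receivers, vectors, r) for all ordered atom pairs
    closer than r_max, excluding self-loops (hypothetical helper)."""
    n = positions.shape[0]
    # diff[i, j] = positions[j] - positions[i]
    diff = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < r_max) & ~np.eye(n, dtype=bool)
    senders, receivers = np.nonzero(mask)
    vectors = diff[senders, receivers]  # receiver minus sender (assumed convention)
    r = dist[senders, receivers]
    return (senders.astype(np.int32), receivers.astype(np.int32),
            vectors.astype(np.float32), r.astype(np.float32))


rng = np.random.default_rng(0)
positions = rng.uniform(0.0, 5.0, size=(4, 3)).astype(np.float32)
senders, receivers, vectors, r = radius_graph(positions, r_max=6.0)
# Edge features sized to match num_edge_in=256 in the example below.
edge_feat = rng.standard_normal((senders.shape[0], 256)).astype(np.float32)
print(senders.shape, vectors.shape, r.shape, edge_feat.shape)
```

These arrays can then be wrapped in `mindspore.Tensor` and placed under the "vectors", "r" and "feat" keys of the edge dictionary, as in the example, with n_edges taken from `senders.shape[0]`.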
- Raises
ValueError – If attention_gate is not "sigmoid" or "softmax".
ValueError – If edge or node features do not contain the required "feat" key.
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb.gns import AttentionInteractionNetwork
>>> attn_net = AttentionInteractionNetwork(
...     num_node_in=256,
...     num_node_out=256,
...     num_edge_in=256,
...     num_edge_out=256,
...     num_mlp_layers=2,
...     mlp_hidden_dim=512,
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)),
...     "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> edges, nodes = attn_net(
...     edge_features,
...     node_features,
...     senders,
...     receivers,
... )
>>> print(edges["feat"].shape, nodes["feat"].shape)
(10, 256) (4, 256)