mindchemistry.cell.orb.AttentionInteractionNetwork

class mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: Literal['sigmoid', 'softmax'] = 'sigmoid', distance_cutoff: bool = True, polynomial_order: Optional[int] = 4, cutoff_rmax: Optional[float] = 6.0)

Attention interaction network. An attention-based message-passing neural network layer that updates edge and node features in molecular graphs.

Parameters
  • num_node_in (int) – Number of input node features.

  • num_node_out (int) – Number of output node features.

  • num_edge_in (int) – Number of input edge features.

  • num_edge_out (int) – Number of output edge features.

  • num_mlp_layers (int) – Number of hidden layers in node and edge update MLPs.

  • mlp_hidden_dim (int) – Hidden dimension size of MLPs.

  • attention_gate (str, optional) – Attention gate type, "sigmoid" or "softmax". Default: "sigmoid".

  • distance_cutoff (bool, optional) – Whether to use distance-based edge cutoff. Default: True.

  • polynomial_order (int, optional) – Order of the polynomial cutoff function; see the sketch after this parameter list. Default: 4.

  • cutoff_rmax (float, optional) – Maximum distance for cutoff. Default: 6.0.
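When distance_cutoff is True, each edge message is smoothly damped to zero as the edge length approaches cutoff_rmax, with polynomial_order controlling how sharply the envelope decays. The exact envelope is defined in the MindChemistry source; the sketch below shows one common choice, a DimeNet-style polynomial cutoff. The function name and formula here are illustrative assumptions, not the layer's verified implementation.

    import numpy as np

    def polynomial_cutoff(r, r_max=6.0, p=4):
        # Assumed DimeNet-style envelope: equals 1 at r = 0, decays smoothly,
        # and reaches 0 with vanishing derivatives at r = r_max.
        x = np.clip(r / r_max, 0.0, 1.0)
        env = (1.0
               - (p + 1) * (p + 2) / 2.0 * x ** p
               + p * (p + 2) * x ** (p + 1)
               - p * (p + 1) / 2.0 * x ** (p + 2))
        return np.where(r < r_max, env, 0.0)

    print(polynomial_cutoff(np.array([0.0, 3.0, 6.0], dtype=np.float32)))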

Inputs:
  • graph_edges (dict) - Edge feature dictionary, must contain key "feat" with shape \((n_{edges}, num\_edge\_in)\).

  • graph_nodes (dict) - Node feature dictionary, must contain key "feat" with shape \((n_{nodes}, num\_node\_in)\).

  • senders (Tensor) - Sender node indices for each edge, shape \((n_{edges},)\).

  • receivers (Tensor) - Receiver node indices for each edge, shape \((n_{edges},)\). The sketch after the Outputs list below illustrates how senders and receivers are used.

Outputs:
  • edges (dict) - Updated edge feature dictionary with key "feat" of shape \((n_{edges}, num\_edge\_out)\).

  • nodes (dict) - Updated node feature dictionary with key "feat" of shape \((n_{nodes}, num\_node\_out)\).
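Together, these inputs and outputs follow the standard gather/attend/scatter pattern of graph message passing: node features are gathered per edge via senders and receivers, an MLP updates the edge features, the attention gate assigns each message a weight, and the gated messages are scatter-summed into the receiving nodes. Below is a minimal NumPy sketch of that pattern with a "sigmoid" gate; the placeholder arithmetic stands in for the layer's real MLPs and is an assumption, not the actual implementation.

    import numpy as np

    rng = np.random.default_rng(0)
    n_nodes, n_edges, dim = 4, 10, 8
    node_feat = rng.standard_normal((n_nodes, dim)).astype(np.float32)
    edge_feat = rng.standard_normal((n_edges, dim)).astype(np.float32)
    senders = rng.integers(0, n_nodes, size=n_edges)
    receivers = rng.integers(0, n_nodes, size=n_edges)

    # Gather: per-edge view of both endpoint node features.
    sent, received = node_feat[senders], node_feat[receivers]

    # Edge update: the real layer feeds [edge, sender, receiver] through
    # an MLP; a plain sum stands in for it here.
    new_edge_feat = edge_feat + sent + received

    # "sigmoid" attention gate: one scalar weight per edge message.
    gate = 1.0 / (1.0 + np.exp(-new_edge_feat.mean(axis=-1, keepdims=True)))

    # Scatter: accumulate gated messages at the receiving nodes.
    aggregated = np.zeros_like(node_feat)
    np.add.at(aggregated, receivers, gate * new_edge_feat)
    new_node_feat = node_feat + aggregated  # residual-style node update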

Raises
  • ValueError – If attention_gate is not "sigmoid" or "softmax".

  • ValueError – If edge or node features do not contain the required "feat" key.

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import AttentionInteractionNetwork
>>> attn_net = AttentionInteractionNetwork(
...     num_node_in=256,
...     num_node_out=256,
...     num_edge_in=256,
...     num_edge_out=256,
...     num_mlp_layers=2,
...     mlp_hidden_dim=512,
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)),
...     "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> edges, nodes = attn_net(
...     edge_features,
...     node_features,
...     senders,
...     receivers,
... )
>>> print(edges["feat"].shape, nodes["feat"].shape)
(10, 256) (4, 256)