mindchemistry.cell.orb.MoleculeGNS
- class mindchemistry.cell.orb.MoleculeGNS(num_node_in_features, num_node_out_features, num_edge_in_features, latent_dim, num_message_passing_steps, num_mlp_layers, mlp_hidden_dim, node_feature_names, edge_feature_names, use_embedding=True, interactions='simple_attention', interaction_params=None)
Molecular graph neural network. Implements a flexible, modular graph neural network for molecular property prediction, based on message passing with attention or other interaction mechanisms. Supports node and edge embeddings, a configurable number of message passing steps, and customizable interaction layers for complex molecular graphs.
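The message passing idea described above can be illustrated with a deliberately simplified, framework-free sketch. It is not the MoleculeGNS implementation: the helper names sigmoid and gated_message_passing_step are hypothetical, the scalar gate is obtained by summing the edge feature instead of from a learned MLP, and only node features are updated. It shows how senders and receivers route gated messages along edges.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Logistic function, used here as a stand-in attention gate."""
    return 1.0 / (1.0 + math.exp(-x))


def gated_message_passing_step(
    node_feat: List[List[float]],
    edge_feat: List[List[float]],
    senders: List[int],
    receivers: List[int],
) -> List[List[float]]:
    """One hypothetical message passing step (not the library's code).

    Each edge e carries the sender's feature vector, scaled by a sigmoid
    gate computed from edge_feat[e]. Every receiver sums its incoming gated
    messages and adds the sum to its own features (residual update).
    """
    dim = len(node_feat[0])
    aggregated = [[0.0] * dim for _ in node_feat]
    for e, (s, r) in enumerate(zip(senders, receivers)):
        # Assumption: a plain sum of the edge feature acts as the attention logit.
        gate = sigmoid(sum(edge_feat[e]))
        for k in range(dim):
            aggregated[r][k] += gate * node_feat[s][k]
    return [
        [h + m for h, m in zip(node_feat[i], aggregated[i])]
        for i in range(len(node_feat))
    ]


# Two nodes, one edge 0 -> 1 whose zero-valued feature gives a gate of 0.5.
updated = gated_message_passing_step(
    node_feat=[[1.0, 0.0], [0.0, 1.0]],
    edge_feat=[[0.0]],
    senders=[0],
    receivers=[1],
)
print(updated)  # [[1.0, 0.0], [0.5, 1.0]]
```

Stacking num_message_passing_steps such steps lets information travel that many hops through the graph.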
- Parameters
num_node_in_features (int) – Number of input features per node.
num_node_out_features (int) – Number of output features per node.
num_edge_in_features (int) – Number of input features per edge.
latent_dim (int) – Latent dimension for node and edge representations.
num_message_passing_steps (int) – Number of message passing layers.
num_mlp_layers (int) – Number of hidden layers in node and edge update MLPs.
mlp_hidden_dim (int) – Hidden dimension size of MLPs.
node_feature_names (List[str]) – List of node feature keys to read from the input node feature dictionary.
edge_feature_names (List[str]) – List of edge feature keys to read from the input edge feature dictionary.
use_embedding (bool, optional) – Whether to use atomic number embedding for nodes. Default: True.
interactions (str, optional) – Type of interaction layer to use (e.g., "simple_attention"). Default: "simple_attention".
interaction_params (Optional[Dict[str, Any]], optional) – Parameters for the interaction layer, e.g., distance cutoff, polynomial order, and attention gate type. Default: None.
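The interaction_params keys used in Examples ("distance_cutoff", "polynomial_order", "cutoff_rmax") suggest a smooth distance cutoff. This page does not give the formula, so the following is only a sketch of one common choice, the polynomial envelope introduced with DimeNet, which equals 1 at r = 0 and falls to exactly 0 at r = r_max. Mapping polynomial_order to p and cutoff_rmax to r_max is an assumption, and polynomial_envelope is a hypothetical helper name.

```python
def polynomial_envelope(r: float, p: int = 4, r_max: float = 6.0) -> float:
    """Smooth polynomial cutoff (assumed form, DimeNet-style envelope).

    With d = r / r_max:
        u(d) = 1 - (p+1)(p+2)/2 * d**p + p(p+2) * d**(p+1) - p(p+1)/2 * d**(p+2)
    u(0) = 1 and u(1) = 0; distances at or beyond r_max are cut off to 0.
    """
    if r >= r_max:
        return 0.0
    d = r / r_max
    return (
        1.0
        - (p + 1) * (p + 2) / 2.0 * d**p
        + p * (p + 2) * d ** (p + 1)
        - p * (p + 1) / 2.0 * d ** (p + 2)
    )


# Evaluate the envelope at a few distances (hypothetical p=4, r_max=6).
for r in (0.0, 3.0, 5.9, 6.0, 7.5):
    print(r, round(polynomial_envelope(r, p=4, r_max=6.0), 5))
```

Because both the envelope and its first derivative vanish at r_max, an edge that crosses the cutoff does not introduce a jump in the weighted messages.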
- Inputs:
edge_features (dict) - Edge feature dictionary, must contain keys specified in edge_feature_names.
node_features (dict) - Node feature dictionary, must contain keys specified in node_feature_names.
senders (Tensor) - Sender node indices for each edge, shape \((n_{edges},)\).
receivers (Tensor) - Receiver node indices for each edge, shape \((n_{edges},)\).
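The senders and receivers tensors, together with the "vectors" and "r" edge features used in Examples, are typically derived from a radius graph over atom positions. A minimal pure-Python sketch, assuming a plain distance cutoff without periodic boundary conditions and the convention vector = position[receiver] - position[sender] (this page does not state the direction); radius_graph is a hypothetical helper:

```python
import math
from typing import List, Sequence, Tuple


def radius_graph(
    positions: Sequence[Sequence[float]], cutoff: float
) -> Tuple[List[int], List[int], List[List[float]], List[float]]:
    """Build directed edges i -> j for every pair of distinct atoms closer
    than `cutoff`. Returns senders, receivers, edge vectors and lengths."""
    senders: List[int] = []
    receivers: List[int] = []
    vectors: List[List[float]] = []
    lengths: List[float] = []
    for i, pos_i in enumerate(positions):
        for j, pos_j in enumerate(positions):
            if i == j:
                continue  # no self-loops
            # Assumed direction: vector points from sender i to receiver j.
            vec = [b - a for a, b in zip(pos_i, pos_j)]
            r = math.sqrt(sum(c * c for c in vec))
            if r < cutoff:
                senders.append(i)
                receivers.append(j)
                vectors.append(vec)
                lengths.append(r)
    return senders, receivers, vectors, lengths


# Three atoms on a line; only atoms 0 and 1 are within the cutoff of each other.
positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
senders, receivers, vectors, r = radius_graph(positions, cutoff=2.0)
print(senders, receivers, r)  # [0, 1] [1, 0] [1.0, 1.0]
```

Each neighbouring pair appears twice, once in each direction, so both atoms receive a message from the other.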
- Outputs:
edges (dict) - Updated edge feature dictionary with key "feat" of shape \((n_{edges}, latent\_dim)\).
nodes (dict) - Updated node feature dictionary with key "feat" of shape \((n_{nodes}, latent\_dim)\).
- Raises
ValueError – If required feature keys are missing from edge_features or node_features.
ValueError – If interactions is not a supported type.
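The two ValueError conditions can be pictured with a hypothetical validation helper. The function name check_gns_inputs and the SUPPORTED_INTERACTIONS tuple are illustrative assumptions; only "simple_attention" is documented on this page, so the real set of supported values may be larger.

```python
from typing import Any, Mapping, Sequence

# Assumption: only the value documented on this page is listed here.
SUPPORTED_INTERACTIONS = ("simple_attention",)


def check_gns_inputs(
    interactions: str,
    node_features: Mapping[str, Any],
    edge_features: Mapping[str, Any],
    node_feature_names: Sequence[str],
    edge_feature_names: Sequence[str],
) -> None:
    """Raise ValueError under the two conditions listed in Raises (sketch)."""
    if interactions not in SUPPORTED_INTERACTIONS:
        raise ValueError(f"Unsupported interactions type: {interactions!r}")
    missing_nodes = [k for k in node_feature_names if k not in node_features]
    missing_edges = [k for k in edge_feature_names if k not in edge_features]
    if missing_nodes or missing_edges:
        raise ValueError(
            f"Missing node feature keys {missing_nodes}, edge feature keys {missing_edges}"
        )


# The edge dictionary lacks the required "feat" key.
try:
    check_gns_inputs("simple_attention", {"feat": 0}, {}, ["feat"], ["feat"])
except ValueError as err:
    print(err)  # Missing node feature keys [], edge feature keys ['feat']
```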
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb.gns import MoleculeGNS
>>> gns_model = MoleculeGNS(
...     num_node_in_features=256,
...     num_node_out_features=3,
...     num_edge_in_features=23,
...     latent_dim=256,
...     interactions="simple_attention",
...     interaction_params={
...         "distance_cutoff": True,
...         "polynomial_order": 4,
...         "cutoff_rmax": 6,
...         "attention_gate": "sigmoid",
...     },
...     num_message_passing_steps=15,
...     num_mlp_layers=2,
...     mlp_hidden_dim=512,
...     use_embedding=True,
...     node_feature_names=["feat"],
...     edge_feature_names=["feat"],
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)),
...     "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> edges, nodes = gns_model(
...     edge_features,
...     node_features,
...     senders,
...     receivers,
... )
>>> print(edges["feat"].shape, nodes["feat"].shape)
(10, 256) (4, 256)