mindchemistry.cell.orb.Orb
- class mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None)
Orb graph regressor. Combines a pretrained base model (e.g., MoleculeGNS) with optional node, graph, and stress regression heads, supporting fine-tuning or feature extraction workflows.
- Parameters
model (MoleculeGNS) – Pretrained or randomly initialized base model for message passing and feature extraction.
node_head (NodeHead, optional) – Regression head for node-level property prediction. Default: None.
graph_head (GraphHead, optional) – Regression head for graph-level property prediction (e.g., energy). Default: None.
stress_head (GraphHead, optional) – Regression head for stress prediction. Default: None.
model_requires_grad (bool, optional) – Whether to fine-tune the base model (True) or freeze its parameters (False). Default: True.
cutoff_layers (int, optional) – If provided, only use the first cutoff_layers message passing layers of the base model. Default: None.
- Inputs:
edge_features (dict) - Edge feature dictionary (e.g., {"vectors": Tensor, "r": Tensor}).
node_features (dict) - Node feature dictionary (e.g., {"atomic_numbers": Tensor, …}).
senders (Tensor) - Sender node indices for each edge. Shape: \((n_{edges},)\).
receivers (Tensor) - Receiver node indices for each edge. Shape: \((n_{edges},)\).
n_node (Tensor) - Number of nodes for each graph in the batch. Shape: \((n_{graphs},)\).
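To illustrate how senders, receivers, and n_node relate to one another, here is a minimal plain-Python sketch that batches two small graphs. It assumes the common convention that the node indices of later graphs are shifted by the number of nodes that precede them in the batch; the source does not spell out the batching convention, and batch_graphs is a hypothetical helper, not part of mindchemistry.

```python
def batch_graphs(graphs):
    """Concatenate graphs given as (num_nodes, senders, receivers) tuples.

    Hypothetical helper: assumes node indices of later graphs are offset by
    the number of nodes that precede them in the batch.
    """
    senders, receivers, n_node = [], [], []
    offset = 0
    for num_nodes, graph_senders, graph_receivers in graphs:
        if len(graph_senders) != len(graph_receivers):
            raise ValueError("senders and receivers need one entry per edge")
        # Shift local node indices into the batch-wide index space.
        senders.extend(s + offset for s in graph_senders)
        receivers.extend(r + offset for r in graph_receivers)
        n_node.append(num_nodes)
        offset += num_nodes
    return senders, receivers, n_node


# Graph A has 3 nodes and 2 edges; graph B has 2 nodes and 1 edge.
senders, receivers, n_node = batch_graphs([
    (3, [0, 1], [1, 2]),
    (2, [0], [1]),
])
print(senders)    # [0, 1, 3]
print(receivers)  # [1, 2, 4]
print(n_node)     # [3, 2]
```

Under this convention, senders and receivers each have one entry per edge, and the entries of n_node sum to the total number of nodes in the batch.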
- Outputs:
output (dict) - Dictionary containing:
  - edges (dict) - Edge features after message passing, e.g., {…, "feat": Tensor}.
  - nodes (dict) - Node features after message passing, e.g., {…, "feat": Tensor}.
  - graph_pred (Tensor) - Graph-level predictions, e.g., energy. Shape: \((n_{graphs}, target\_property\_dim)\).
  - node_pred (Tensor) - Node-level predictions. Shape: \((n_{nodes}, target\_property\_dim)\).
  - stress_pred (Tensor) - Stress predictions (if stress_head is provided). Shape: \((n_{graphs}, 6)\).
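The relationship between the node-level shape \((n_{nodes}, \cdot)\) and the graph-level shape \((n_{graphs}, \cdot)\) can be sketched with a per-graph mean over node rows, driven by n_node. This is only an illustration of the shape bookkeeping, assuming a mean aggregation like the node_aggregation="mean" setting in the Examples section; segment_mean is a hypothetical helper and not the library's implementation.

```python
def segment_mean(node_values, n_node):
    """Average per-node rows within each graph of a batch.

    node_values: list of per-node feature rows, length sum(n_node).
    n_node: number of nodes in each graph (each assumed to be >= 1).
    Returns one averaged row per graph.
    """
    if sum(n_node) != len(node_values):
        raise ValueError("sum(n_node) must equal the number of node rows")
    if any(count < 1 for count in n_node):
        raise ValueError("every graph needs at least one node")
    graph_values, start = [], 0
    for count in n_node:
        rows = node_values[start:start + count]
        dim = len(rows[0])
        # Mean over the nodes of this graph, one value per feature column.
        graph_values.append(
            [sum(row[d] for row in rows) / count for d in range(dim)]
        )
        start += count
    return graph_values


node_values = [[1.0], [3.0], [5.0], [2.0], [4.0]]  # 5 nodes, 1 feature each
graph_values = segment_mean(node_values, n_node=[3, 2])
print(graph_values)  # [[3.0], [3.0]]
```

Five node rows collapse into two graph rows, matching the \((n_{nodes}, d) \rightarrow (n_{graphs}, d)\) pattern of node_pred versus graph_pred.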
- Raises
ValueError – If neither node_head nor graph_head is provided.
ValueError – If cutoff_layers exceeds the number of message passing steps in the base model.
ValueError – If atomic_numbers is not provided when graph_head is required.
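The first two conditions above can be sketched as a standalone check. This is a hedged illustration only: check_orb_config, try_config, and the error messages are hypothetical and are not taken from the mindchemistry source; the third condition (missing atomic_numbers) arises at call time and is not modeled here.

```python
def check_orb_config(node_head, graph_head, cutoff_layers,
                     num_message_passing_steps):
    """Hypothetical stand-in for the construction-time checks under Raises."""
    # At least one regression head is required.
    if node_head is None and graph_head is None:
        raise ValueError("provide at least one of node_head or graph_head")
    # cutoff_layers may not exceed the base model's message passing steps.
    if cutoff_layers is not None and cutoff_layers > num_message_passing_steps:
        raise ValueError(
            f"cutoff_layers={cutoff_layers} exceeds "
            f"num_message_passing_steps={num_message_passing_steps}"
        )


def try_config(**kwargs):
    """Return 'ok' or the ValueError message for a given configuration."""
    try:
        check_orb_config(**kwargs)
    except ValueError as err:
        return f"ValueError: {err}"
    return "ok"


# Placeholder strings stand in for real head objects.
no_heads = try_config(node_head=None, graph_head=None,
                      cutoff_layers=None, num_message_passing_steps=15)
too_deep = try_config(node_head="head", graph_head=None,
                      cutoff_layers=20, num_message_passing_steps=15)
valid = try_config(node_head="head", graph_head="head",
                   cutoff_layers=10, num_message_passing_steps=15)
print(no_heads)
print(too_deep)
print(valid)
```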
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead
>>> # Build the regressor: base model plus energy, node, and stress heads.
>>> orb = Orb(
...     model=MoleculeGNS(
...         num_node_in_features=256,
...         num_node_out_features=3,
...         num_edge_in_features=23,
...         latent_dim=256,
...         interactions="simple_attention",
...         interaction_params={
...             "distance_cutoff": True,
...             "polynomial_order": 4,
...             "cutoff_rmax": 6,
...             "attention_gate": "sigmoid",
...         },
...         num_message_passing_steps=15,
...         num_mlp_layers=2,
...         mlp_hidden_dim=512,
...         use_embedding=True,
...         node_feature_names=["feat"],
...         edge_feature_names=["feat"],
...     ),
...     graph_head=EnergyHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=1,
...         node_aggregation="mean",
...         reference_energy_name="vasp-shifted",
...         train_reference=True,
...         predict_atom_avg=True,
...     ),
...     node_head=NodeHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=3,
...         remove_mean=True,
...     ),
...     stress_head=GraphHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=6,
...         compute_stress=True,
...     ),
... )
>>> # Random single-graph input: 4 atoms, 10 edges.
>>> n_atoms = 4
>>> n_edges = 10
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> # One-hot embedding of the atomic numbers (118 elements).
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> output = orb(edge_features, node_features, senders, receivers, n_node)
>>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape)
(1, 1) (4, 3) (1, 6)
- predict(edge_features, node_features, senders, receivers, n_node, atomic_numbers)
Predict node-level and/or graph-level attributes.
- Parameters
edge_features – A dictionary, e.g., {"vectors": Tensor, "r": Tensor}.
node_features – A dictionary, e.g., {"atomic_numbers": Tensor, "positions": Tensor, "atomic_numbers_embedding": Tensor}.
senders – A tensor of shape (n_edges,) containing the sender node indices.
receivers – A tensor of shape (n_edges,) containing the receiver node indices.
n_node – A tensor of shape (1,) containing the number of nodes.
atomic_numbers – A tensor of atomic numbers for reference energy calculation.
- Returns
A dictionary containing the predictions:
  - graph_pred: Graph-level predictions (e.g., energy) of shape (n_graphs, graph_property_dim).
  - stress_pred: Stress predictions (if stress_head is provided) of shape (n_graphs, stress_dim).
  - node_pred: Node-level predictions of shape (n_nodes, node_property_dim).
- Return type
output_dict