mindchemistry.cell.orb.Orb

class mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None)

Orb graph regressor. Combines a pretrained base model (e.g., MoleculeGNS) with optional node, graph, and stress regression heads, supporting fine-tuning or feature extraction workflows.

Parameters
  • model (MoleculeGNS) – Pretrained or randomly initialized base model for message passing and feature extraction.

  • node_head (NodeHead, optional) – Regression head for node-level property prediction. Default: None.

  • graph_head (GraphHead, optional) – Regression head for graph-level property prediction (e.g., energy). Default: None.

  • stress_head (GraphHead, optional) – Regression head for stress prediction. Default: None.

  • model_requires_grad (bool, optional) – Whether to fine-tune the base model (True) or freeze its parameters (False). Default: True.

  • cutoff_layers (int, optional) – If provided, only use the first cutoff_layers message passing layers of the base model. Default: None.
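
The effect of cutoff_layers can be pictured with a minimal pure-Python sketch. It assumes the base model applies its message passing steps one after another and that cutoff_layers simply keeps the first few; the toy layers and the helper name run_layers are hypothetical and are not part of mindchemistry.

```python
def run_layers(layers, x, cutoff_layers=None):
    """Apply layers in order, optionally keeping only the first cutoff_layers."""
    # Hypothetical helper: slicing keeps the leading layers and drops the rest.
    active = layers if cutoff_layers is None else layers[:cutoff_layers]
    for layer in active:
        x = layer(x)
    return x


# Fifteen toy "message passing steps", each of which adds 1 to its input.
layers = [lambda x: x + 1 for _ in range(15)]

full = run_layers(layers, 0)                        # all 15 steps
truncated = run_layers(layers, 0, cutoff_layers=5)  # first 5 steps only
print(full, truncated)  # 15 5
```

Under this reading, a smaller cutoff_layers trades depth of message passing for a cheaper forward pass.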

Inputs:
  • edge_features (dict) - Edge feature dictionary (e.g., {"vectors": Tensor, "r": Tensor}).

  • node_features (dict) - Node feature dictionary (e.g., {"atomic_numbers": Tensor, …}).

  • senders (Tensor) - Sender node indices for each edge. Shape: \((n_{edges},)\).

  • receivers (Tensor) - Receiver node indices for each edge. Shape: \((n_{edges},)\).

  • n_node (Tensor) - Number of nodes for each graph in the batch. Shape: \((n_{graphs},)\).
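
The relationship between senders, receivers, and n_node can be illustrated with a short pure-Python sketch. It assumes the common convention that the graphs in a batch are concatenated and that edge indices are offset by the number of preceding nodes; this page does not spell that convention out, and batch_graphs is a hypothetical helper rather than a mindchemistry function.

```python
def batch_graphs(graphs):
    """Concatenate graphs into flat senders, receivers, and n_node lists.

    Each graph is a tuple (num_nodes, senders, receivers) whose node
    indices are local to that graph. Assumed convention, not library code.
    """
    senders, receivers, n_node = [], [], []
    offset = 0
    for num_nodes, local_senders, local_receivers in graphs:
        # Shift local indices so they point into the concatenated node list.
        senders.extend(i + offset for i in local_senders)
        receivers.extend(i + offset for i in local_receivers)
        n_node.append(num_nodes)
        offset += num_nodes
    return senders, receivers, n_node


# Two toy graphs: a 3-node chain and a 2-node pair with edges in both directions.
graphs = [
    (3, [0, 1], [1, 2]),
    (2, [0, 1], [1, 0]),
]
senders, receivers, n_node = batch_graphs(graphs)
print(senders)    # [0, 1, 3, 4]
print(receivers)  # [1, 2, 4, 3]
print(n_node)     # [3, 2]
```

With this layout, senders and receivers both have length n_edges and n_node has one entry per graph, matching the shapes listed above.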

Outputs:
  • output (dict) - Dictionary containing:

    - edges (dict) - Edge features after message passing, e.g., {…, "feat": Tensor}.
    - nodes (dict) - Node features after message passing, e.g., {…, "feat": Tensor}.
    - graph_pred (Tensor) - Graph-level predictions, e.g., energy. Shape: \((n_{graphs}, target\_property\_dim)\).
    - node_pred (Tensor) - Node-level predictions. Shape: \((n_{nodes}, target\_property\_dim)\).
    - stress_pred (Tensor) - Stress predictions (if stress_head is provided). Shape: \((n_{graphs}, 6)\).
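
Because node_pred has one row per node across the whole batch, it is often useful to split it back into per-graph pieces. The sketch below does this with plain Python lists, assuming rows are ordered graph by graph in the same order as n_node; split_by_graph is a hypothetical helper and the nested list is a toy stand-in for the real tensor.

```python
def split_by_graph(node_rows, n_node):
    """Split per-node rows into one chunk per graph.

    Assumes the first n_node[0] rows belong to graph 0, the next
    n_node[1] rows to graph 1, and so on.
    """
    if sum(n_node) != len(node_rows):
        raise ValueError("n_node must sum to the number of node rows")
    chunks, start = [], 0
    for count in n_node:
        chunks.append(node_rows[start:start + count])
        start += count
    return chunks


# Toy stand-in for output["node_pred"]: 5 nodes, 3 values per node.
node_pred = [[float(i), 0.0, 0.0] for i in range(5)]
n_node = [3, 2]

per_graph = split_by_graph(node_pred, n_node)
print([len(chunk) for chunk in per_graph])  # [3, 2]
print(per_graph[1][0])                      # [3.0, 0.0, 0.0]
```

With real outputs, the tensors would first need converting to host arrays (for example with asnumpy()) before applying the same slicing.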

Raises
  • ValueError – If neither node_head nor graph_head is provided.

  • ValueError – If cutoff_layers exceeds the number of message passing steps in the base model.

  • ValueError – If atomic_numbers is not provided when graph_head is required.
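
The three documented error conditions can be summarized in a small standalone sketch. It only mirrors the wording above and is not the library's implementation; the function names, the num_steps argument, and the error messages are all hypothetical.

```python
def check_heads_and_cutoff(node_head, graph_head, cutoff_layers, num_steps):
    """Mirror the two construction-time checks described above (sketch only)."""
    if node_head is None and graph_head is None:
        raise ValueError("Provide at least one of node_head or graph_head.")
    if cutoff_layers is not None and cutoff_layers > num_steps:
        raise ValueError(
            f"cutoff_layers ({cutoff_layers}) exceeds the number of "
            f"message passing steps ({num_steps})."
        )


def check_atomic_numbers(graph_head, atomic_numbers):
    """Mirror the call-time check: a graph head needs atomic numbers."""
    if graph_head is not None and atomic_numbers is None:
        raise ValueError("atomic_numbers is required when graph_head is used.")


# Placeholder strings stand in for real head objects in this sketch.
cases = [
    dict(node_head=None, graph_head=None, cutoff_layers=None, num_steps=15),
    dict(node_head="node", graph_head=None, cutoff_layers=16, num_steps=15),
    dict(node_head="node", graph_head="graph", cutoff_layers=5, num_steps=15),
]
for kwargs in cases:
    try:
        check_heads_and_cutoff(**kwargs)
        print("ok")
    except ValueError as err:
        print("ValueError:", err)
```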

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead
>>> # Build the regressor: a MoleculeGNS base model plus graph, node, and stress heads.
>>> # The instance is named orb_model so that it does not shadow the Orb class.
>>> orb_model = Orb(
...     model=MoleculeGNS(
...         num_node_in_features=256,
...         num_node_out_features=3,
...         num_edge_in_features=23,
...         latent_dim=256,
...         interactions="simple_attention",
...         interaction_params={
...             "distance_cutoff": True,
...             "polynomial_order": 4,
...             "cutoff_rmax": 6,
...             "attention_gate": "sigmoid",
...         },
...         num_message_passing_steps=15,
...         num_mlp_layers=2,
...         mlp_hidden_dim=512,
...         use_embedding=True,
...         node_feature_names=["feat"],
...         edge_feature_names=["feat"],
...     ),
...     graph_head=EnergyHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=1,
...         node_aggregation="mean",
...         reference_energy_name="vasp-shifted",
...         train_reference=True,
...         predict_atom_avg=True,
...     ),
...     node_head=NodeHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=3,
...         remove_mean=True,
...     ),
...     stress_head=GraphHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=6,
...         compute_stress=True,
...     ),
... )
>>> # Build a random single-graph input with 4 atoms and 10 edges.
>>> n_atoms = 4
>>> n_edges = 10
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> # One-hot encode atomic numbers 1..118 into columns 0..117.
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> # Run the forward pass; the output dict holds one prediction per configured head.
>>> output = orb_model(edge_features, node_features, senders, receivers, n_node)
>>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape)
(1, 1) (4, 3) (1, 6)

predict(edge_features, node_features, senders, receivers, n_node, atomic_numbers)

Predict node and/or graph level attributes.

Parameters
  • edge_features – A dictionary, e.g., {"vectors": Tensor, "r": Tensor}.

  • node_features – A dictionary, e.g., {"atomic_numbers": Tensor, "positions": Tensor, "atomic_numbers_embedding": Tensor}.

  • senders – A tensor of shape (n_edges,) containing the sender node indices.

  • receivers – A tensor of shape (n_edges,) containing the receiver node indices.

  • n_node – A tensor of shape (n_graphs,) containing the number of nodes in each graph, e.g., shape (1,) for a single graph.

  • atomic_numbers – A tensor of atomic numbers for reference energy calculation.

Returns

A dictionary containing the predictions:

  - graph_pred: Graph-level predictions (e.g., energy) of shape (n_graphs, graph_property_dim).
  - stress_pred: Stress predictions (if stress_head is provided) of shape (n_graphs, stress_dim).
  - node_pred: Node-level predictions of shape (n_nodes, node_property_dim).

Return type

dict