mindchemistry.cell.orb.EnergyHead

class mindchemistry.cell.orb.EnergyHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, predict_atom_avg: bool = True, reference_energy_name: str = 'mp-traj-d3', train_reference: bool = False, dropout: Optional[float] = None, node_aggregation: Optional[str] = 'mean')

Graph-level energy prediction head. A neural network head that predicts either the total energy or the per-atom average energy of a molecular graph. It aggregates node-level predictions into a graph-level value, can apply a reference energy offset, and supports both output modes through predict_atom_avg.
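To make the description concrete, here is a minimal NumPy sketch of what a head like this computes for a single graph: a per-node MLP maps each "feat" row from latent_dim to target_property_dim, and the node outputs are mean-pooled into one graph-level row. It is an illustration under stated assumptions, not MindChemistry's implementation: the helper names `mlp_forward` and `energy_head_sketch`, the ReLU activation, and the random weights are hypothetical, and dropout and the reference energy offset are left out.

```python
import numpy as np


def mlp_forward(x, weights, biases):
    """Plain MLP (assumed form): ReLU after every hidden layer, linear output layer."""
    h = x
    for w, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ w + b, 0.0)
    return h @ weights[-1] + biases[-1]


def energy_head_sketch(feat, num_mlp_layers=1, mlp_hidden_dim=16,
                       target_property_dim=1, seed=0):
    """Hypothetical stand-in for the head's forward pass on one graph.

    feat: (n_nodes, latent_dim) array, playing the role of node_features["feat"].
    Returns {"graph_pred": array of shape (1, target_property_dim)}.
    """
    latent_dim = feat.shape[1]
    rng = np.random.default_rng(seed)
    # Layer widths: latent_dim -> mlp_hidden_dim (x num_mlp_layers) -> target_property_dim
    dims = [latent_dim] + [mlp_hidden_dim] * num_mlp_layers + [target_property_dim]
    # Random weights stand in for trained parameters (illustration only)
    weights = [0.1 * rng.standard_normal((d_in, d_out)).astype(np.float32)
               for d_in, d_out in zip(dims[:-1], dims[1:])]
    biases = [np.zeros(d_out, dtype=np.float32) for d_out in dims[1:]]
    node_pred = mlp_forward(feat, weights, biases)  # (n_nodes, target_property_dim)
    # Mean pooling over nodes (node_aggregation="mean") yields one row for the graph
    graph_pred = node_pred.mean(axis=0, keepdims=True)
    return {"graph_pred": graph_pred}


feat = np.random.default_rng(1).standard_normal((4, 256)).astype(np.float32)
out = energy_head_sketch(feat)
print(out["graph_pred"].shape)  # (1, 1)
```

The output shape matches the documented "graph_pred" shape (1, target_property_dim) for one graph.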

Parameters
  • latent_dim (int) – Input feature dimension for each node.

  • num_mlp_layers (int) – Number of hidden layers in the MLP.

  • mlp_hidden_dim (int) – Hidden dimension size of the MLP.

  • target_property_dim (int) – Output dimension of the energy property (typically 1).

  • predict_atom_avg (bool, optional) – Whether to predict the per-atom average energy instead of the total energy. Default: True.

  • reference_energy_name (str, optional) – Name of the reference energy used for the offset, e.g., "vasp-shifted". Default: "mp-traj-d3".

  • train_reference (bool, optional) – Whether to make the reference energy a learnable parameter. Default: False.

  • dropout (Optional[float], optional) – Dropout rate for the MLP. Default: None.

  • node_aggregation (Optional[str], optional) – Aggregation method for node predictions, e.g., "mean" or "sum". Default: "mean".
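With node_aggregation, per-node predictions are reduced to one value per graph. The sketch below generalizes to a batch of several graphs whose node counts are listed in n_node (the documented inputs describe a single graph with n_node of shape (1,)), and it assumes "mean" and "sum" are the only accepted methods. The helper `aggregate_nodes` is hypothetical; it uses `np.add.at` as a segment sum, and the mean is that sum divided by each graph's node count.

```python
import numpy as np


def aggregate_nodes(node_pred, n_node, node_aggregation="mean"):
    """Hypothetical per-graph reduction of node predictions.

    node_pred: (total_nodes, dim) predictions for all nodes in the batch.
    n_node: (n_graphs,) node count of each graph, in batch order.
    Returns: (n_graphs, dim).
    """
    n_node = np.asarray(n_node)
    if n_node.sum() != node_pred.shape[0]:
        raise ValueError("sum(n_node) must equal the number of node rows")
    # graph_index[i] is the index of the graph that node i belongs to
    graph_index = np.repeat(np.arange(len(n_node)), n_node)
    summed = np.zeros((len(n_node), node_pred.shape[1]), dtype=node_pred.dtype)
    np.add.at(summed, graph_index, node_pred)  # unbuffered segment sum
    if node_aggregation == "sum":
        return summed
    if node_aggregation == "mean":
        # Guard against empty graphs to avoid division by zero
        return summed / np.maximum(n_node, 1)[:, None]
    # Assumption: only "mean" and "sum" are supported
    raise ValueError(f"Unsupported node_aggregation: {node_aggregation!r}")


node_pred = np.array([[1.0], [2.0], [3.0], [10.0], [20.0]])
print(aggregate_nodes(node_pred, [3, 2], "sum").ravel().tolist())   # [6.0, 30.0]
print(aggregate_nodes(node_pred, [3, 2], "mean").ravel().tolist())  # [2.0, 15.0]
```

"sum" suits extensive quantities such as a total energy, while "mean" keeps the graph-level value on a per-node scale regardless of graph size.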

Inputs:
  • node_features (dict) - Node feature dictionary; must contain the key "feat" with shape \((n_{nodes}, latent\_dim)\).

  • n_node (Tensor) - Number of nodes in the graph, shape \((1,)\).

Outputs:
  • output (dict) - Dictionary containing key "graph_pred" with value of shape \((1, target\_property\_dim)\).

Raises
  • ValueError – If required feature keys are missing from node_features.

  • ValueError – If node_aggregation is not a supported method.
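The two documented ValueError cases can also be checked up front, before calling the head. The function below is a hypothetical helper, not part of MindChemistry; it assumes "mean" and "sum" are the supported aggregation methods, and its shape checks on "feat" and n_node are extra assumptions derived from the Inputs section rather than documented raising behavior.

```python
import numpy as np

# Assumption: the documentation lists these only as examples
SUPPORTED_AGGREGATIONS = ("mean", "sum")


def check_energy_head_inputs(node_features, n_node, latent_dim, node_aggregation="mean"):
    """Hypothetical pre-flight check mirroring the documented ValueError cases."""
    if node_aggregation not in SUPPORTED_AGGREGATIONS:
        raise ValueError(f"Unsupported node_aggregation: {node_aggregation!r}")
    if "feat" not in node_features:
        raise ValueError('node_features must contain the key "feat"')
    feat = np.asarray(node_features["feat"])
    n_node = np.asarray(n_node)
    # Extra (assumed) checks: feat is (n_nodes, latent_dim) and n_node matches it
    if feat.ndim != 2 or feat.shape[1] != latent_dim:
        raise ValueError(f'"feat" must have shape (n_nodes, {latent_dim}), got {feat.shape}')
    if n_node.shape != (1,) or int(n_node[0]) != feat.shape[0]:
        raise ValueError("n_node must have shape (1,) and equal the number of rows in feat")
    return True


feats = {"feat": np.zeros((4, 256), dtype=np.float32)}
print(check_energy_head_inputs(feats, [4], latent_dim=256))  # True
try:
    check_energy_head_inputs({}, [4], latent_dim=256)
except ValueError as err:
    print(err)  # node_features must contain the key "feat"
```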

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import EnergyHead
>>> energy_head = EnergyHead(
...     latent_dim=256,
...     num_mlp_layers=1,
...     mlp_hidden_dim=256,
...     target_property_dim=1,
...     node_aggregation="mean",
...     reference_energy_name="vasp-shifted",
...     train_reference=True,
...     predict_atom_avg=True,
... )
>>> n_atoms = 4
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> output = energy_head(node_features, n_node)
>>> print(output['graph_pred'].shape)
(1, 1)

predict(node_features, n_node, atomic_numbers=None)

Predict the graph-level energy.

Parameters
  • node_features – Node features tensor.

  • n_node – Number of nodes.

  • atomic_numbers – Optional atomic numbers used to compute the reference energy. Default: None.

Returns

  • graph_pred – Energy prediction.
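How predict_atom_avg and the reference energy could combine into a total energy is sketched below. It assumes a linear per-element reference (one energy value per atomic number, summed over the atoms) and that a per-atom average prediction is scaled by the atom count before the reference is added; neither detail is confirmed by this page. The function name `predict_total_energy` and the reference values are hypothetical, chosen only to make the arithmetic easy to follow.

```python
import numpy as np


def predict_total_energy(graph_pred, n_node, atomic_numbers, ref_per_element,
                         predict_atom_avg=True):
    """Hypothetical conversion of a head output into a total energy.

    graph_pred: (1, 1) head output (a per-atom average if predict_atom_avg).
    n_node: (1,) number of atoms in the graph.
    atomic_numbers: (n_atoms,) integers in 1..118.
    ref_per_element: (119,) assumed reference energy per element, indexed by Z.
    """
    energy = graph_pred[:, 0].astype(np.float64)  # shape (1,)
    if predict_atom_avg:
        energy = energy * np.asarray(n_node)  # per-atom average -> total
    # Assumed linear reference: one value per atom, summed over the graph
    reference = ref_per_element[np.asarray(atomic_numbers)].sum()
    return energy + reference


ref = np.zeros(119)
ref[1], ref[8] = -1.0, -4.0  # illustrative values only, not a fitted reference table
water = np.array([8, 1, 1])
print(predict_total_energy(np.array([[-0.5]]), np.array([3]), water, ref).tolist())  # [-7.5]
```

For this water-like input, the per-atom average of -0.5 becomes -1.5 for three atoms, and the reference contributes -4.0 - 1.0 - 1.0 = -6.0, giving -7.5.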