mindchemistry.cell.orb.EnergyHead
- class mindchemistry.cell.orb.EnergyHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, predict_atom_avg: bool = True, reference_energy_name: str = 'mp-traj-d3', train_reference: bool = False, dropout: Optional[float] = None, node_aggregation: Optional[str] = 'mean')
Graph-level energy prediction head. Implements a neural network head that predicts either the total energy or the per-atom average energy of a molecular graph. Supports aggregation over nodes, a reference energy offset, and both output modes (total or per-atom average).
- Parameters
latent_dim (int) – Input feature dimension for each node.
num_mlp_layers (int) – Number of hidden layers in MLP.
mlp_hidden_dim (int) – Hidden dimension size of MLP.
target_property_dim (int) – Output dimension of energy property (typically 1).
predict_atom_avg (bool, optional) – Whether to predict per-atom average energy instead of total energy. Default: True.
reference_energy_name (str, optional) – Reference energy name for offset, e.g., "vasp-shifted". Default: "mp-traj-d3".
train_reference (bool, optional) – Whether to train reference energy as learnable parameter. Default: False.
dropout (Optional[float], optional) – Dropout rate for MLP. Default: None.
node_aggregation (str, optional) – Aggregation method for node predictions, e.g., "mean" or "sum". Default: "mean".
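As a rough illustration of what the node_aggregation and predict_atom_avg options mean, the plain-Python sketch below reduces per-node vectors to one graph-level vector and converts a per-atom average into a total. The names aggregate_nodes and total_from_atom_avg are hypothetical, and this is not the library's implementation. It only assumes that "mean" averages over nodes, "sum" adds them, and that a per-atom average multiplied by the atom count gives a total (ignoring any reference energy offset).

```python
from typing import List


def aggregate_nodes(node_values: List[List[float]], reduction: str = "mean") -> List[float]:
    """Reduce per-node vectors of shape (n_nodes, dim) to one vector of shape (dim,).

    Hypothetical helper for illustration only; not part of mindchemistry.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Unsupported node_aggregation: {reduction!r}")
    if not node_values:
        raise ValueError("node_values must contain at least one node")
    # Column-wise sum over all nodes.
    summed = [sum(column) for column in zip(*node_values)]
    if reduction == "sum":
        return summed
    # "mean": divide each column sum by the number of nodes.
    return [s / len(node_values) for s in summed]


def total_from_atom_avg(atom_avg_energy: float, n_atoms: int) -> float:
    """Convert a per-atom average energy into a total energy."""
    return atom_avg_energy * n_atoms


node_values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
mean_vec = aggregate_nodes(node_values, "mean")  # [4.0, 5.0]
sum_vec = aggregate_nodes(node_values, "sum")    # [16.0, 20.0]
total = total_from_atom_avg(-1.5, 4)             # -6.0
```

With "mean" the graph-level value does not grow with the number of atoms, which pairs naturally with a per-atom average target; with "sum" it scales with system size.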
- Inputs:
node_features (dict) - Node feature dictionary, must contain key "feat" with shape \((n_{nodes}, latent\_dim)\).
n_node (Tensor) - Number of nodes in graph, shape \((1,)\).
- Outputs:
output (dict) - Dictionary containing key "graph_pred" with value of shape \((1, target\_property\_dim)\).
- Raises
ValueError – If required feature keys are missing in node_features.
ValueError – If node_aggregation is not a supported type.
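The input contract above can be checked before calling the head. The sketch below is a hypothetical pre-check, check_energy_head_inputs, which is not part of mindchemistry. It assumes only that the feature object exposes a .shape attribute, as a MindSpore Tensor does, and it uses a stand-in object so the example runs without MindSpore installed.

```python
from types import SimpleNamespace


def check_energy_head_inputs(node_features: dict, latent_dim: int) -> int:
    """Validate node_features against the documented contract and return n_nodes.

    Hypothetical helper for illustration only; not part of mindchemistry.
    """
    if "feat" not in node_features:
        raise ValueError('node_features must contain the key "feat"')
    shape = tuple(node_features["feat"].shape)
    # The documented shape is (n_nodes, latent_dim).
    if len(shape) != 2 or shape[1] != latent_dim:
        raise ValueError(f'"feat" must have shape (n_nodes, {latent_dim}), got {shape}')
    return shape[0]


# Stand-in for a tensor: any object exposing a `.shape` attribute works here.
fake_feat = SimpleNamespace(shape=(4, 256))
n_nodes = check_energy_head_inputs({"feat": fake_feat}, latent_dim=256)  # 4
```

A check like this reports a missing "feat" key or a mismatched latent_dim with a clear message before the call reaches the network.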
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb.gns import EnergyHead
>>> energy_head = EnergyHead(
...     latent_dim=256,
...     num_mlp_layers=1,
...     mlp_hidden_dim=256,
...     target_property_dim=1,
...     node_aggregation="mean",
...     reference_energy_name="vasp-shifted",
...     train_reference=True,
...     predict_atom_avg=True,
... )
>>> n_atoms = 4
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> output = energy_head(node_features, n_node)
>>> print(output['graph_pred'].shape)
(1, 1)