mindchemistry.cell.orb.Orb

class mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None)

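Orb graph regressor. Combines a pretrained base model (such as MoleculeGNS) with optional node, graph, and stress regression heads, supporting fine-tuning or feature-extraction workflows.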

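Parameters:
  • model (MoleculeGNS) - The base model used for message passing and feature extraction, either pretrained or randomly initialized.

  • node_head (NodeHead, optional) - Regression head for node-level property prediction. Default: None

  • graph_head (GraphHead, optional) - Regression head for graph-level property prediction (for example, energy). Default: None

  • stress_head (GraphHead, optional) - Regression head for stress prediction. Default: None

  • model_requires_grad (bool, optional) - Whether to fine-tune the base model (True) or freeze its parameters (False). Default: True

  • cutoff_layers (int, optional) - If provided, only the first "cutoff_layers" message-passing layers of the base model are used. Default: None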

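Inputs:
  • edge_features (dict) - Dictionary of edge features (for example, {"vectors": Tensor, "r": Tensor}).

  • node_features (dict) - Dictionary of node features (for example, {"atomic_numbers": Tensor, …}).

  • senders (Tensor) - Index of the sending node of each edge. Shape: \((n_{edges},)\)

  • receivers (Tensor) - Index of the receiving node of each edge. Shape: \((n_{edges},)\)

  • n_node (Tensor) - Number of nodes in each graph of the batch. Shape: \((n_{graphs},)\)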
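
The relationship between senders, receivers, and n_node for a batch of several graphs can be sketched as follows. This is a minimal illustration, assuming the common convention that batched graphs are concatenated and each graph's node indices are shifted by the number of nodes that precede it; this page does not state that convention explicitly. The helper batch_graphs and the dictionary keys are hypothetical and not part of mindchemistry, and plain Python lists stand in for Tensors.

```python
def batch_graphs(graphs):
    """Concatenate per-graph edge lists into one batched edge list.

    Each graph is a dict with keys "n_node" (int), "senders" and
    "receivers" (lists of node indices local to that graph).
    These key names are illustrative, not part of mindchemistry.
    """
    senders, receivers, n_node = [], [], []
    offset = 0  # batch index of the first node of the current graph
    for g in graphs:
        if len(g["senders"]) != len(g["receivers"]):
            raise ValueError("senders and receivers must have the same length")
        for s, r in zip(g["senders"], g["receivers"]):
            if not (0 <= s < g["n_node"] and 0 <= r < g["n_node"]):
                raise ValueError("edge index out of range for its graph")
            # Shift local indices so they address nodes in the whole batch.
            senders.append(s + offset)
            receivers.append(r + offset)
        n_node.append(g["n_node"])
        offset += g["n_node"]
    return senders, receivers, n_node


graphs = [
    {"n_node": 3, "senders": [0, 1, 2], "receivers": [1, 2, 0]},
    {"n_node": 2, "senders": [0, 1], "receivers": [1, 0]},
]
senders, receivers, n_node = batch_graphs(graphs)
print(senders)    # [0, 1, 2, 3, 4]
print(receivers)  # [1, 2, 0, 4, 3]
print(n_node)     # [3, 2]
```

Under this assumption, senders and receivers both have length n_edges, n_node has length n_graphs, and the entries of n_node sum to the total number of nodes.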

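Outputs:
  • output (dict) - Dictionary containing:

    • edges (dict) - Edge features after message passing, for example {…, "feat": Tensor}

    • nodes (dict) - Node features after message passing, for example {…, "feat": Tensor}

    • graph_pred (Tensor) - Graph-level prediction, for example energy. Shape: \((n_{graphs}, target\_property\_dim)\)

    • node_pred (Tensor) - Node-level prediction. Shape: \((n_{nodes}, target\_property\_dim)\)

    • stress_pred (Tensor) - Stress prediction (if stress_head is provided). Shape: \((n_{graphs}, 6)\)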
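
To make the node-level and graph-level shapes concrete, the sketch below reduces a table of per-node rows of shape (n_nodes, d) to per-graph rows of shape (n_graphs, d) by averaging within each graph, using n_node to find the graph boundaries. It is one plausible reading of a "mean" node aggregation and is not taken from the mindchemistry implementation; the function name segment_mean is hypothetical, and plain Python lists stand in for Tensors.

```python
def segment_mean(node_values, n_node):
    """Average per-node rows within each graph of a batch.

    node_values: list of n_nodes rows, each a list of floats.
    n_node: node count of each graph, in batch order.
    Returns a list of n_graphs rows.
    """
    if sum(n_node) != len(node_values):
        raise ValueError("sum(n_node) must equal the number of node rows")
    out, start = [], 0
    for count in n_node:
        if count <= 0:
            raise ValueError("every graph must contain at least one node")
        rows = node_values[start:start + count]
        dim = len(rows[0])
        # Column-wise mean over the rows that belong to this graph.
        out.append([sum(row[d] for row in rows) / count for d in range(dim)])
        start += count
    return out


node_values = [
    [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],  # graph 0 (3 nodes)
    [10.0, 20.0], [30.0, 40.0],          # graph 1 (2 nodes)
]
graph_values = segment_mean(node_values, n_node=[3, 2])
print(graph_values)  # [[3.0, 4.0], [20.0, 30.0]]
```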

Raises:
  • ValueError - If neither node_head nor graph_head is provided.

  • ValueError - If cutoff_layers exceeds the number of message-passing steps in the base model.

  • ValueError - If atomic_numbers is not provided when graph_head requires it.
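
The three conditions above can be mirrored in a small standalone check. This is a sketch under stated assumptions, not the library's code: the function names, argument names, and error messages are hypothetical, and the real class may perform these checks at different points (construction versus call time).

```python
from typing import Optional


def check_orb_arguments(
    has_node_head: bool,
    has_graph_head: bool,
    num_message_passing_steps: int,
    cutoff_layers: Optional[int] = None,
) -> int:
    """Return the number of message-passing layers that would be used."""
    if not has_node_head and not has_graph_head:
        raise ValueError("provide at least one of node_head or graph_head")
    if cutoff_layers is None:
        return num_message_passing_steps
    if cutoff_layers > num_message_passing_steps:
        raise ValueError(
            f"cutoff_layers ({cutoff_layers}) exceeds the number of "
            f"message-passing steps ({num_message_passing_steps})"
        )
    return cutoff_layers


def check_node_features(node_features: dict, graph_head_needs_atomic_numbers: bool) -> None:
    """Reject inputs that lack atomic_numbers when the graph head needs them."""
    if graph_head_needs_atomic_numbers and "atomic_numbers" not in node_features:
        raise ValueError("node_features must contain 'atomic_numbers'")


print(check_orb_arguments(True, True, 15))                    # 15
print(check_orb_arguments(True, False, 15, cutoff_layers=5))  # 5
try:
    check_orb_arguments(False, False, 15)
except ValueError as err:
    print("ValueError:", err)
```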

Supported platforms:

Ascend

Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead
>>> # Bind the instance to a lowercase name so the Orb class is not shadowed.
>>> orb = Orb(
...     model=MoleculeGNS(
...         num_node_in_features=256,
...         num_node_out_features=3,
...         num_edge_in_features=23,
...         latent_dim=256,
...         interactions="simple_attention",
...         interaction_params={
...             "distance_cutoff": True,
...             "polynomial_order": 4,
...             "cutoff_rmax": 6,
...             "attention_gate": "sigmoid",
...         },
...         num_message_passing_steps=15,
...         num_mlp_layers=2,
...         mlp_hidden_dim=512,
...         use_embedding=True,
...         node_feature_names=["feat"],
...         edge_feature_names=["feat"],
...     ),
...     graph_head=EnergyHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=1,
...         node_aggregation="mean",
...         reference_energy_name="vasp-shifted",
...         train_reference=True,
...         predict_atom_avg=True,
...     ),
...     node_head=NodeHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=3,
...         remove_mean=True,
...     ),
...     stress_head=GraphHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=6,
...         compute_stress=True,
...     ),
... )
>>> # A single graph with 4 atoms and 10 edges.
>>> n_atoms = 4
>>> n_edges = 10
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> # One-hot encode the atomic numbers (1 to 118) into 118 columns.
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32))
... }
>>> # Random edge vectors, lengths, and connectivity, used only to check shapes.
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> output = orb(edge_features, node_features, senders, receivers, n_node)
>>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape)
(1, 1) (4, 3) (1, 6)