mindchemistry.cell.orb.Orb
- class mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None)[source]
Orb graph regressor. Combines a pretrained base model (such as MoleculeGNS) with optional node, graph, and stress regression heads, supporting both fine-tuning and feature-extraction workflows.
- Parameters:
model (MoleculeGNS) - Pretrained or randomly initialized base model used for message passing and feature extraction.
node_head (NodeHead, optional) - Regression head for node-level property prediction. Default: None.
graph_head (GraphHead, optional) - Regression head for graph-level property prediction (e.g. energy). Default: None.
stress_head (GraphHead, optional) - Regression head for stress prediction. Default: None.
model_requires_grad (bool, optional) - Whether to fine-tune the base model (True) or freeze its parameters (False); see the feature-extraction sketch after this list. Default: True.
cutoff_layers (int, optional) - If provided, only the first cutoff_layers message-passing layers of the base model are used. Default: None.
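A minimal feature-extraction sketch, assuming a hypothetical pretrained MoleculeGNS instance gns and a NodeHead instance head constructed as in the Examples section below; only the Orb-specific arguments are the point here:
>>> extractor = Orb(
...     model=gns,                   # assumed pretrained MoleculeGNS instance
...     node_head=head,              # assumed NodeHead instance
...     model_requires_grad=False,   # freeze the base-model parameters
...     cutoff_layers=5,             # keep only the first 5 message-passing layers
... )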
- Inputs:
edge_features (dict) - Dictionary of edge features (e.g. {"vectors": Tensor, "r": Tensor}).
node_features (dict) - Dictionary of node features (e.g. {"atomic_numbers": Tensor, …}).
senders (Tensor) - Sending node index of each edge. Shape: \((n_{edges},)\).
receivers (Tensor) - Receiving node index of each edge. Shape: \((n_{edges},)\).
n_node (Tensor) - Number of nodes in each graph of the batch. Shape: \((n_{graphs},)\). See the batching sketch after this list.
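senders, receivers, and n_node together describe a batch of graphs packed into a single node array. The sketch below assumes the common packing convention in which node indices are global across the batch, for a hypothetical batch of two graphs with 3 and 2 atoms:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> # Two graphs packed together: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
>>> n_node = Tensor([3, 2], mindspore.int32)
>>> # Edge endpoints index the packed node array, so edges of the second graph
>>> # use the global indices 3 and 4.
>>> senders = Tensor(np.array([0, 1, 2, 3, 4], dtype=np.int32))
>>> receivers = Tensor(np.array([1, 2, 0, 4, 3], dtype=np.int32))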
- Outputs:
output (dict) - Dictionary with the following entries (a short unpacking sketch follows this list):
edges (dict) - Edge features after message passing, e.g. {…, "feat": Tensor}.
nodes (dict) - Node features after message passing, e.g. {…, "feat": Tensor}.
graph_pred (Tensor) - Graph-level predictions, e.g. energies. Shape: \((n_{graphs}, target\_property\_dim)\).
node_pred (Tensor) - Node-level predictions. Shape: \((n_{nodes}, target\_property\_dim)\).
stress_pred (Tensor) - Stress predictions (only if stress_head is provided). Shape: \((n_{graphs}, 6)\).
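The same forward pass therefore serves both regression and feature extraction. A minimal unpacking sketch, assuming output comes from the call shown in the Examples section below:
>>> energy = output["graph_pred"]              # shape (n_graphs, target_property_dim)
>>> per_node = output["node_pred"]             # shape (n_nodes, target_property_dim)
>>> stress = output["stress_pred"]             # shape (n_graphs, 6), present only with stress_head
>>> node_embeddings = output["nodes"]["feat"]  # post-message-passing node features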
- Raises:
ValueError - If neither node_head nor graph_head is provided.
ValueError - If cutoff_layers exceeds the number of message-passing steps in the base model.
ValueError - If atomic_numbers is not provided when graph_head requires it.
- Supported Platforms:
Ascend
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead
>>> orb = Orb(
...     model=MoleculeGNS(
...         num_node_in_features=256,
...         num_node_out_features=3,
...         num_edge_in_features=23,
...         latent_dim=256,
...         interactions="simple_attention",
...         interaction_params={
...             "distance_cutoff": True,
...             "polynomial_order": 4,
...             "cutoff_rmax": 6,
...             "attention_gate": "sigmoid",
...         },
...         num_message_passing_steps=15,
...         num_mlp_layers=2,
...         mlp_hidden_dim=512,
...         use_embedding=True,
...         node_feature_names=["feat"],
...         edge_feature_names=["feat"],
...     ),
...     graph_head=EnergyHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=1,
...         node_aggregation="mean",
...         reference_energy_name="vasp-shifted",
...         train_reference=True,
...         predict_atom_avg=True,
...     ),
...     node_head=NodeHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=3,
...         remove_mean=True,
...     ),
...     stress_head=GraphHead(
...         latent_dim=256,
...         num_mlp_layers=1,
...         mlp_hidden_dim=256,
...         target_property_dim=6,
...         compute_stress=True,
...     ),
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> n_node = Tensor([n_atoms], mindspore.int32)
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> output = orb(edge_features, node_features, senders, receivers, n_node)
>>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape)
(1, 1) (4, 3) (1, 6)