mindchemistry.cell.orb.MoleculeGNS

class mindchemistry.cell.orb.MoleculeGNS(num_node_in_features: int, num_node_out_features: int, num_edge_in_features: int, latent_dim: int, num_message_passing_steps: int, num_mlp_layers: int, mlp_hidden_dim: int, node_feature_names: List[str], edge_feature_names: List[str], use_embedding: bool = True, interactions: str = 'simple_attention', interaction_params: Optional[Dict[str, Any]] = None)

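Molecular graph neural network. Implements a flexible, modular graph neural network for molecular property prediction, using message passing built on attention or other interaction mechanisms. Supports node and edge embeddings, multiple message-passing steps, and customizable interaction layers for complex molecular graphs.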
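As a rough illustration of the message passing described above, the sketch below runs one generic edge-update, aggregate, node-update step with NumPy. It is not MindChemistry's implementation: the function name `message_passing_step`, the single weight matrices `w_edge` and `w_node` standing in for the update MLPs, the `tanh` nonlinearity, the residual updates, and the plain sum aggregation (used here in place of the attention-based interaction) are all assumptions made for brevity.

```python
import numpy as np


def message_passing_step(edge_feat, node_feat, senders, receivers, w_edge, w_node):
    """One generic edge-update / aggregate / node-update step (illustrative only)."""
    # Edge update: each edge sees its own feature plus its sender's and receiver's features.
    edge_in = np.concatenate(
        [edge_feat, node_feat[senders], node_feat[receivers]], axis=-1
    )
    new_edge = edge_feat + np.tanh(edge_in @ w_edge)  # residual update
    # Aggregation: sum the updated edge features arriving at each receiver node.
    agg = np.zeros_like(node_feat)
    np.add.at(agg, receivers, new_edge)
    # Node update: each node sees its own feature plus its aggregated messages.
    node_in = np.concatenate([node_feat, agg], axis=-1)
    new_node = node_feat + np.tanh(node_in @ w_node)  # residual update
    return new_edge, new_node


rng = np.random.default_rng(0)
latent_dim, n_nodes, n_edges = 8, 4, 10
nodes = rng.standard_normal((n_nodes, latent_dim)).astype(np.float32)
edges = rng.standard_normal((n_edges, latent_dim)).astype(np.float32)
senders = rng.integers(0, n_nodes, size=n_edges)
receivers = rng.integers(0, n_nodes, size=n_edges)
# Hypothetical stand-ins for the learned edge and node update MLPs.
w_edge = (0.1 * rng.standard_normal((3 * latent_dim, latent_dim))).astype(np.float32)
w_node = (0.1 * rng.standard_normal((2 * latent_dim, latent_dim))).astype(np.float32)

for _ in range(3):  # plays the role of num_message_passing_steps
    edges, nodes = message_passing_step(edges, nodes, senders, receivers, w_edge, w_node)
print(edges.shape, nodes.shape)  # (10, 8) (4, 8)
```

Because every update maps back to `latent_dim` features, stacking steps leaves the shapes unchanged, which is consistent with the documented shapes of the "feat" outputs.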

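Parameters:
  • num_node_in_features (int) - Number of input features per node.

  • num_node_out_features (int) - Number of output features per node.

  • num_edge_in_features (int) - Number of input features per edge.

  • latent_dim (int) - Latent dimension of the node and edge representations.

  • num_message_passing_steps (int) - Number of message-passing layers.

  • num_mlp_layers (int) - Number of hidden layers in the node and edge update MLPs.

  • mlp_hidden_dim (int) - Hidden dimension of the MLPs.

  • node_feature_names (List[str]) - List of node feature keys to read from the input dictionary.

  • edge_feature_names (List[str]) - List of edge feature keys to read from the input dictionary.

  • use_embedding (bool, optional) - Whether to use an atomic-number embedding for the nodes. Default: True

  • interactions (str, optional) - Type of interaction layer to use (e.g. "simple_attention"). Default: "simple_attention"

  • interaction_params (Optional[Dict[str, Any]], optional) - Parameters of the interaction layer, such as the cutoff, polynomial order, and gate type. Default: None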
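The interaction_params keys used in the example (distance_cutoff, polynomial_order, cutoff_rmax, attention_gate) suggest a distance-dependent smoothing of edge contributions. The sketch below shows one common choice, a smooth polynomial envelope of the kind used in DimeNet, combined with a sigmoid gate. Whether MoleculeGNS uses exactly this formula is an assumption, and `polynomial_envelope` and `sigmoid_gate` are hypothetical helpers, not MindChemistry functions.

```python
import numpy as np


def polynomial_envelope(r, r_max, p):
    """Smooth cutoff that equals 1 at r = 0 and 0 for r >= r_max (illustrative)."""
    d = np.asarray(r, dtype=np.float64) / r_max  # distance scaled so the cutoff sits at d = 1
    u = (
        1.0
        - (p + 1) * (p + 2) / 2.0 * d**p
        + p * (p + 2) * d ** (p + 1)
        - p * (p + 1) / 2.0 * d ** (p + 2)
    )
    # Edges at or beyond the cutoff contribute nothing.
    return np.where(d < 1.0, u, 0.0)


def sigmoid_gate(logits, r, r_max, p):
    """Assumed per-edge weight: sigmoid of an attention logit, scaled by the envelope."""
    gate = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    return gate * polynomial_envelope(r, r_max, p)


r = np.array([0.0, 3.0, 6.0, 9.0])
weights = polynomial_envelope(r, r_max=6.0, p=4)  # mirrors cutoff_rmax=6, polynomial_order=4
print(weights)  # values: 1.0, 0.65625, 0.0, 0.0
```

The envelope polynomial itself is exactly zero at d = 1 for any order p, so the `np.where` only matters for edges beyond the cutoff, where the raw polynomial would otherwise become nonzero again.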

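Inputs:
  • edge_features (dict) - Dictionary of edge features; must contain the keys listed in edge_feature_names.

  • node_features (dict) - Dictionary of node features; must contain the keys listed in node_feature_names.

  • senders (Tensor) - Index of the sending node of each edge, with shape \((n_{edges},)\).

  • receivers (Tensor) - Index of the receiving node of each edge, with shape \((n_{edges},)\).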
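The senders and receivers inputs, together with the "vectors" and "r" edge features used in the example, are typically derived from atomic positions. The following NumPy sketch builds them for every ordered pair of distinct atoms closer than a cutoff. The helper name `radius_graph` and the sign convention (each vector points from the sender to the receiver) are assumptions; MindChemistry may construct its graphs differently.

```python
import numpy as np


def radius_graph(positions, cutoff):
    """Edges for all ordered pairs of distinct atoms closer than `cutoff` (illustrative)."""
    pos = np.asarray(positions, dtype=np.float32)
    diff = pos[None, :, :] - pos[:, None, :]  # diff[i, j] = pos[j] - pos[i]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)  # exclude self-loops
    senders, receivers = np.nonzero(mask)
    vectors = diff[senders, receivers]  # receiver position minus sender position
    r = dist[senders, receivers]
    return senders.astype(np.int32), receivers.astype(np.int32), vectors, r


positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
senders, receivers, vectors, r = radius_graph(positions, cutoff=2.0)
print(senders, receivers, r)  # [0 1] [1 0] [1. 1.]
```

Only atoms 0 and 1 are within the cutoff, so the graph has two directed edges, one in each direction; atom 2 has no edges at all.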

Outputs:
  • edges (dict) - Dictionary of updated edge features; the "feat" entry has shape \((n_{edges}, latent\_dim)\).

  • nodes (dict) - Dictionary of updated node features; the "feat" entry has shape \((n_{nodes}, latent\_dim)\).

Raises:
  • ValueError - If a required feature key is missing from edge_features or node_features.

  • ValueError - If interactions is not a supported type.
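The two ValueError conditions can be pictured as simple up-front checks. The sketch below is plain Python; the helper names and the set of supported interaction types (containing only "simple_attention", the one value this page documents) are assumptions, not the library's actual validation code.

```python
SUPPORTED_INTERACTIONS = {"simple_attention"}  # assumption: only the value documented here


def check_feature_keys(features, required_names, kind):
    """Raise ValueError if any name in `required_names` is absent from `features`."""
    missing = [name for name in required_names if name not in features]
    if missing:
        raise ValueError(f"{kind} is missing required feature keys: {missing}")


def check_interactions(interactions):
    """Raise ValueError for an unsupported interaction type."""
    if interactions not in SUPPORTED_INTERACTIONS:
        raise ValueError(f"Unsupported interactions type: {interactions!r}")


check_feature_keys({"feat": None}, ["feat"], "node_features")  # passes silently
try:
    check_feature_keys({"r": None}, ["feat"], "edge_features")
except ValueError as err:
    print(err)  # edge_features is missing required feature keys: ['feat']
```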

Supported Platforms:

Ascend

Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindchemistry.cell.orb.gns import MoleculeGNS
>>> gns_model = MoleculeGNS(
...     num_node_in_features=256,
...     num_node_out_features=3,
...     num_edge_in_features=23,
...     latent_dim=256,
...     interactions="simple_attention",
...     interaction_params={
...         "distance_cutoff": True,
...         "polynomial_order": 4,
...         "cutoff_rmax": 6,
...         "attention_gate": "sigmoid",
...     },
...     num_message_passing_steps=15,
...     num_mlp_layers=2,
...     mlp_hidden_dim=512,
...     use_embedding=True,
...     node_feature_names=["feat"],
...     edge_feature_names=["feat"],
... )
>>> n_atoms = 4
>>> n_edges = 10
>>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
>>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
>>> for i, num in enumerate(atomic_numbers.asnumpy()):
...     atomic_numbers_embedding_np[i, num - 1] = 1.0
>>> node_features = {
...     "atomic_numbers": atomic_numbers,
...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
... }
>>> edge_features = {
...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)),
...     "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32))
... }
>>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
>>> edges, nodes = gns_model(
...     edge_features,
...     node_features,
...     senders,
...     receivers,
... )
>>> print(edges["feat"].shape, nodes["feat"].shape)
(10, 256) (4, 256)
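In the example, the loop that fills atomic_numbers_embedding_np builds a one-hot encoding of the atomic numbers. The same array can be produced with a single NumPy indexing expression; this is plain NumPy (no MindChemistry API involved) and assumes, as in the example, that atomic numbers lie between 1 and 118.

```python
import numpy as np

n_atoms = 4
atomic_numbers = np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)

# Row k of the 118 x 118 identity matrix is the one-hot vector for atomic number k + 1.
one_hot = np.eye(118, dtype=np.float32)[atomic_numbers - 1]

# Same result as the explicit loop used in the example.
loop_version = np.zeros((n_atoms, 118), dtype=np.float32)
for i, num in enumerate(atomic_numbers):
    loop_version[i, num - 1] = 1.0
assert np.array_equal(one_hot, loop_version)
```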