mindchemistry.cell.orb.MoleculeGNS
- class mindchemistry.cell.orb.MoleculeGNS(num_node_in_features: int, num_node_out_features: int, num_edge_in_features: int, latent_dim: int, num_message_passing_steps: int, num_mlp_layers: int, mlp_hidden_dim: int, node_feature_names: List[str], edge_feature_names: List[str], use_embedding: bool = True, interactions: str = 'simple_attention', interaction_params: Optional[Dict[str, Any]] = None)[源代码]
分子图神经网络。实现用于分子性质预测的灵活模块化图神经网络,基于注意力或其他交互机制的消息传递。支持节点和边嵌入、多个消息传递步骤,以及用于复杂分子图的可定制交互层。
- 参数:
num_node_in_features (int) - 每个节点的输入特征数量。
num_node_out_features (int) - 每个节点的输出特征数量。
num_edge_in_features (int) - 每条边的输入特征数量。
latent_dim (int) - 节点和边表示的潜在维度。
num_message_passing_steps (int) - 消息传递层的数量。
num_mlp_layers (int) - 节点和边更新MLP的隐藏层数量。
mlp_hidden_dim (int) - MLP的隐藏维度大小。
node_feature_names (List[str]) - 从输入字典中使用的节点特征键列表。
edge_feature_names (List[str]) - 从输入字典中使用的边特征键列表。
use_embedding (bool,可选) - 是否对节点使用原子序数嵌入。默认值:
True
。interactions (str,可选) - 要使用的交互层类型(例如,
"simple_attention"
)。默认值:"simple_attention"
。interaction_params (Optional[Dict[str, Any]],可选) - 交互层的参数,例如截断、多项式阶数、门类型。默认值:
None
。
- 输入:
edge_features (dict) - 边特征字典,必须包含 edge_feature_names 中指定的键。
node_features (dict) - 节点特征字典,必须包含 node_feature_names 中指定的键。
senders (Tensor) - 每条边的发送节点索引,形状为 \((n_{edges},)\)。
receivers (Tensor) - 每条边的接收节点索引,形状为 \((n_{edges},)\)。
- 输出:
edges (dict) - 更新的边特征字典,键"feat"的形状为 \((n_{edges}, latent\_dim)\)。
nodes (dict) - 更新的节点特征字典,键"feat"的形状为 \((n_{nodes}, latent\_dim)\)。
- 异常:
ValueError - 如果 edge_features 或 node_features 中缺少必需的特征键。
ValueError - 如果 interactions 不是支持的类型。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> import mindspore >>> from mindspore import Tensor >>> from mindchemistry.cell.orb.gns import MoleculeGNS >>> gns_model = MoleculeGNS( ... num_node_in_features=256, ... num_node_out_features=3, ... num_edge_in_features=23, ... latent_dim=256, ... interactions="simple_attention", ... interaction_params={ ... "distance_cutoff": True, ... "polynomial_order": 4, ... "cutoff_rmax": 6, ... "attention_gate": "sigmoid", ... }, ... num_message_passing_steps=15, ... num_mlp_layers=2, ... mlp_hidden_dim=512, ... use_embedding=True, ... node_feature_names=["feat"], ... edge_feature_names=["feat"], ... ) >>> n_atoms = 4 >>> n_edges = 10 >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) >>> for i, num in enumerate(atomic_numbers.asnumpy()): ... atomic_numbers_embedding_np[i, num - 1] = 1.0 >>> node_features = { ... "atomic_numbers": atomic_numbers, ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) ... } >>> edge_features = { ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) ... } >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) >>> edges, nodes = gns_model( ... edge_features, ... node_features, ... senders, ... receivers, ... ) >>> print(edges["feat"].shape, nodes["feat"].shape) (10, 256) (4, 256)