mindscience.models.GraphCast.GraphCastNet
- class mindscience.models.GraphCast.GraphCastNet(vg_in_channels, vg_out_channels, vm_in_channels, em_in_channels, eg2m_in_channels, em2g_in_channels, latent_dims, processing_steps, g2m_src_idx, g2m_dst_idx, m2m_src_idx, m2m_dst_idx, m2g_src_idx, m2g_dst_idx, mesh_node_feats, mesh_edge_feats, g2m_edge_feats, m2g_edge_feats, per_variable_level_mean, per_variable_level_std, recompute=False)
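GraphCast is an autoregressive model built on graph neural networks and a novel high-resolution, multi-scale mesh representation. For details, see GraphCast: Learning skillful medium-range global weather forecasting.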
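Because the model is autoregressive, a multi-step forecast is produced by feeding each predicted state back in as the next input. The sketch below illustrates only that loop. `rollout` and the toy `step_fn` are hypothetical stand-ins (a trained `GraphCastNet` would take the place of `step_fn`), and states are plain NumPy arrays of shape `(num_grid_nodes, feature_size)`.

```python
import numpy as np
from typing import Callable, List


def rollout(step_fn: Callable[[np.ndarray], np.ndarray],
            initial_state: np.ndarray,
            num_steps: int) -> List[np.ndarray]:
    """Apply a one-step forecaster repeatedly, feeding each output back in.

    Hypothetical helper: `step_fn` maps a state of shape
    (num_grid_nodes, feature_size) to the next state of the same shape.
    """
    states = []
    state = initial_state
    for _ in range(num_steps):
        state = step_fn(state)  # prediction for the next time step
        states.append(state)
    return states


# Toy stand-in for a trained model: damp every feature by 10% per step.
def toy_step(state: np.ndarray) -> np.ndarray:
    return 0.9 * state


x0 = np.ones((4, 3), dtype=np.float32)  # 4 grid nodes, 3 features
trajectory = rollout(toy_step, x0, num_steps=3)
print(len(trajectory), trajectory[-1].shape)
```

With the real network, note that the documented input shape has a leading `batch_size` axis while the documented output shape does not, so a wrapper used as `step_fn` may need to re-add that axis before each call.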
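- Parameters:
vg_in_channels (int) - Input dimension of the grid node features.
vg_out_channels (int) - Final (output) dimension of the grid node features.
vm_in_channels (int) - Dimension of the mesh node features.
em_in_channels (int) - Dimension of the mesh edge features.
eg2m_in_channels (int) - Dimension of the grid-to-mesh edge features.
em2g_in_channels (int) - Dimension of the mesh-to-grid edge features.
latent_dims (int) - Number of dimensions of the hidden layers.
processing_steps (int) - Number of processing steps.
g2m_src_idx (Tensor) - Source node indices of the grid-to-mesh edges.
g2m_dst_idx (Tensor) - Destination node indices of the grid-to-mesh edges.
m2m_src_idx (Tensor) - Source node indices of the mesh-to-mesh edges.
m2m_dst_idx (Tensor) - Destination node indices of the mesh-to-mesh edges.
m2g_src_idx (Tensor) - Source node indices of the mesh-to-grid edges.
m2g_dst_idx (Tensor) - Destination node indices of the mesh-to-grid edges.
mesh_node_feats (Tensor) - Mesh node features.
mesh_edge_feats (Tensor) - Mesh edge features.
g2m_edge_feats (Tensor) - Grid-to-mesh edge features.
m2g_edge_feats (Tensor) - Mesh-to-grid edge features.
per_variable_level_mean (Tensor) - Mean of the per-variable, per-level inverse variance of the time differences.
per_variable_level_std (Tensor) - Standard deviation of the per-variable, per-level inverse variance of the time differences.
recompute (bool, optional) - Whether to enable recomputation. Default: False.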
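The index and feature tensors have to agree with one another: each `*_src_idx`/`*_dst_idx` pair lists one directed edge per position, the index values must fall inside the node set they point into (grid-to-mesh edges start at grid nodes and end at mesh nodes, and so on), and each edge-feature tensor should have one row per edge and as many columns as the matching `*_in_channels` argument, as in the example on this page. The hypothetical `check_edge_set` helper sketches those checks with NumPy arrays; it is not part of MindScience.

```python
import numpy as np


def check_edge_set(name, src_idx, dst_idx, edge_feats,
                   num_src_nodes, num_dst_nodes, in_channels):
    """Hypothetical consistency check for one directed edge set (g2m, m2m or m2g)."""
    if src_idx.ndim != 1 or src_idx.shape != dst_idx.shape:
        raise ValueError(f"{name}: src/dst indices must be 1-D and equally long")
    num_edges = src_idx.shape[0]
    # Sources index into the sending node set, destinations into the receiving one.
    if num_edges and not (src_idx.min() >= 0 and src_idx.max() < num_src_nodes):
        raise ValueError(f"{name}: source index out of range")
    if num_edges and not (dst_idx.min() >= 0 and dst_idx.max() < num_dst_nodes):
        raise ValueError(f"{name}: destination index out of range")
    # One feature row per edge, `in_channels` columns per row.
    if edge_feats.shape != (num_edges, in_channels):
        raise ValueError(f"{name}: expected edge features of shape "
                         f"{(num_edges, in_channels)}, got {edge_feats.shape}")


num_grid, num_mesh = 6, 3
g2m_src = np.array([0, 1, 2, 3, 4, 5])       # every grid node sends one edge
g2m_dst = np.array([0, 0, 1, 1, 2, 2])       # ... to a mesh node
g2m_feats = np.zeros((6, 4), dtype=np.float32)

check_edge_set("g2m", g2m_src, g2m_dst, g2m_feats, num_grid, num_mesh, in_channels=4)

# Swapping the node counts (treating the edges as mesh -> grid) is caught.
try:
    check_edge_set("g2m", g2m_src, g2m_dst, g2m_feats, num_mesh, num_grid, in_channels=4)
except ValueError as err:
    print(err)
```

The same call would be repeated for the mesh-to-mesh set (mesh to mesh, `em_in_channels`) and the mesh-to-grid set (mesh to grid, `em2g_in_channels`).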
- Inputs:
input (Tensor) - Tensor of shape \((batch\_size, height\_size * width\_size, feature\_size)\).
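The input packs a latitude-longitude grid into a single node axis of length `height_size * width_size`. A NumPy sketch of that packing follows, assuming row-major (latitude-major) node ordering; the ordering actually required must match how `g2m_src_idx` and `m2g_dst_idx` number the grid nodes, which this page does not specify.

```python
import numpy as np

batch_size, height_size, width_size, feature_size = 1, 4, 8, 5

# Gridded fields laid out as (batch, lat, lon, feature).
fields = np.random.rand(batch_size, height_size, width_size, feature_size).astype(np.float32)

# Flatten lat/lon into one node axis (row-major: node = lat * width_size + lon).
nodes = fields.reshape(batch_size, height_size * width_size, feature_size)
print(nodes.shape)  # (1, 32, 5)

# Grid point (lat=2, lon=3) becomes node 2 * width_size + 3 = 19.
assert np.array_equal(nodes[0, 2 * width_size + 3], fields[0, 2, 3])
```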
- Outputs:
output (Tensor) - Tensor of shape \((height\_size * width\_size, feature\_size)\).
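The output has no batch axis and keeps the flattened node axis. Assuming the same row-major node ordering as on the input side (an assumption, not something this page states), it can be reshaped back into a latitude-longitude field. The array below is a stand-in for real model output.

```python
import numpy as np

height_size, width_size, feature_size = 4, 8, 5

# Stand-in for the network output, shape (height_size * width_size, feature_size).
output = np.arange(height_size * width_size * feature_size,
                   dtype=np.float32).reshape(height_size * width_size, feature_size)

# Undo the row-major flattening to recover a (lat, lon, feature) field ...
field = output.reshape(height_size, width_size, feature_size)
# ... and, if downstream code expects it, restore a leading batch axis of size 1.
batched = field[np.newaxis]
print(field.shape, batched.shape)  # (4, 8, 5) (1, 4, 8, 5)
```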
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, Tensor
>>> from mindscience.models.GraphCast.graphcastnet import GraphCastNet
>>>
>>> mesh_node_num = 2562
>>> grid_node_num = 32768
>>> mesh_edge_num = 20460
>>> g2m_edge_num = 50184
>>> m2g_edge_num = 98304
>>> vm_in_channels = 3
>>> em_in_channels = 4
>>> eg2m_in_channels = 4
>>> em2g_in_channels = 4
>>> feature_num = 69
>>> g2m_src_idx = Tensor(np.random.randint(0, grid_node_num, size=[g2m_edge_num]), ms.int32)
>>> g2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[g2m_edge_num]), ms.int32)
>>> m2m_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2g_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[m2g_edge_num]), ms.int32)
>>> m2g_dst_idx = Tensor(np.random.randint(0, grid_node_num, size=[m2g_edge_num]), ms.int32)
>>> mesh_node_feats = Tensor(np.random.rand(mesh_node_num, vm_in_channels).astype(np.float32), ms.float32)
>>> mesh_edge_feats = Tensor(np.random.rand(mesh_edge_num, em_in_channels).astype(np.float32), ms.float32)
>>> g2m_edge_feats = Tensor(np.random.rand(g2m_edge_num, eg2m_in_channels).astype(np.float32), ms.float32)
>>> m2g_edge_feats = Tensor(np.random.rand(m2g_edge_num, em2g_in_channels).astype(np.float32), ms.float32)
>>> per_variable_level_mean = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> per_variable_level_std = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> grid_node_feats = Tensor(np.random.rand(grid_node_num, feature_num).astype(np.float32), ms.float32)
>>> graphcast_model = GraphCastNet(vg_in_channels=feature_num,
...                                vg_out_channels=feature_num,
...                                vm_in_channels=vm_in_channels,
...                                em_in_channels=em_in_channels,
...                                eg2m_in_channels=eg2m_in_channels,
...                                em2g_in_channels=em2g_in_channels,
...                                latent_dims=512,
...                                processing_steps=4,
...                                g2m_src_idx=g2m_src_idx,
...                                g2m_dst_idx=g2m_dst_idx,
...                                m2m_src_idx=m2m_src_idx,
...                                m2m_dst_idx=m2m_dst_idx,
...                                m2g_src_idx=m2g_src_idx,
...                                m2g_dst_idx=m2g_dst_idx,
...                                mesh_node_feats=mesh_node_feats,
...                                mesh_edge_feats=mesh_edge_feats,
...                                g2m_edge_feats=g2m_edge_feats,
...                                m2g_edge_feats=m2g_edge_feats,
...                                per_variable_level_mean=per_variable_level_mean,
...                                per_variable_level_std=per_variable_level_std)
>>> out = graphcast_model(Tensor(grid_node_feats, ms.float32))
>>> print(out.shape)
(32768, 69)
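The example fills every graph tensor with random values. For real data, the GraphCast paper describes node features as the cosine of latitude plus the sine and cosine of longitude (3 values, consistent with `vm_in_channels = 3` in the example) and edge features as the edge length plus the 3-D displacement between endpoints (4 values, consistent with `em_in_channels = 4`). The sketch below builds features of that shape from hypothetical node coordinates and an undirected mesh edge list. It is a simplification: the paper expresses displacements in a receiver-local frame, while here raw unit-sphere coordinates are used, and whether this class expects exactly this feature layout is an assumption. All function names are hypothetical.

```python
import numpy as np


def lat_lon_to_xyz(lat_deg, lon_deg):
    """Unit-sphere Cartesian coordinates from latitude/longitude in degrees."""
    lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
    return np.stack([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)], axis=-1)


def mesh_node_features(lat_deg, lon_deg):
    """(num_nodes, 3): cos(lat), sin(lon), cos(lon)."""
    lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
    return np.stack([np.cos(lat), np.sin(lon), np.cos(lon)], axis=-1).astype(np.float32)


def directed_edges(undirected_edges):
    """Turn (u, v) pairs into src/dst index arrays covering both directions."""
    pairs = np.asarray(undirected_edges, dtype=np.int32)
    src = np.concatenate([pairs[:, 0], pairs[:, 1]])
    dst = np.concatenate([pairs[:, 1], pairs[:, 0]])
    return src, dst


def edge_features(xyz, src, dst):
    """(num_edges, 4): length and sender-to-receiver displacement, scaled by the longest edge."""
    disp = xyz[dst] - xyz[src]
    length = np.linalg.norm(disp, axis=-1, keepdims=True)
    return (np.concatenate([length, disp], axis=-1) / length.max()).astype(np.float32)


# Four toy mesh nodes: north pole, two equator points, south pole.
lat = np.array([90.0, 0.0, 0.0, -90.0])
lon = np.array([0.0, 0.0, 120.0, 0.0])
node_feats = mesh_node_features(lat, lon)                                   # (4, 3)
m2m_src_idx, m2m_dst_idx = directed_edges([(0, 1), (1, 2), (2, 0), (0, 3)])
m2m_edge_feats = edge_features(lat_lon_to_xyz(lat, lon), m2m_src_idx, m2m_dst_idx)  # (8, 4)
```

The resulting NumPy arrays could then be wrapped with `Tensor(..., ms.int32)` or `Tensor(..., ms.float32)` exactly as the example above does for its random inputs.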