mindscience.models.GraphCast.GraphCastNet
- class mindscience.models.GraphCast.GraphCastNet(vg_in_channels, vg_out_channels, vm_in_channels, em_in_channels, eg2m_in_channels, em2g_in_channels, latent_dims, processing_steps, g2m_src_idx, g2m_dst_idx, m2m_src_idx, m2m_dst_idx, m2g_src_idx, m2g_dst_idx, mesh_node_feats, mesh_edge_feats, g2m_edge_feats, m2g_edge_feats, per_variable_level_mean, per_variable_level_std, recompute=False)
GraphCast is an autoregressive model built on graph neural networks and a novel high-resolution multi-scale mesh representation. For details, see GraphCast: Learning skillful medium-range global weather forecasting.
- Parameters
vg_in_channels (int) – The number of input feature channels per grid node.
vg_out_channels (int) – The number of output feature channels per grid node.
vm_in_channels (int) – The number of input feature channels per mesh node.
em_in_channels (int) – The number of input feature channels per mesh edge.
eg2m_in_channels (int) – The number of input feature channels per grid-to-mesh edge.
em2g_in_channels (int) – The number of input feature channels per mesh-to-grid edge.
latent_dims (int) – The dimension of the hidden layers.
processing_steps (int) – The number of processing steps.
g2m_src_idx (Tensor) – The source node index of grid to mesh edges.
g2m_dst_idx (Tensor) – The destination node index of grid to mesh edges.
m2m_src_idx (Tensor) – The source node index of mesh to mesh edges.
m2m_dst_idx (Tensor) – The destination node index of mesh to mesh edges.
m2g_src_idx (Tensor) – The source node index of mesh to grid edges.
m2g_dst_idx (Tensor) – The destination node index of mesh to grid edges.
mesh_node_feats (Tensor) – The features of mesh nodes.
mesh_edge_feats (Tensor) – The features of mesh edges.
g2m_edge_feats (Tensor) – The features of grid to mesh edges.
m2g_edge_feats (Tensor) – The features of mesh to grid edges.
per_variable_level_mean (Tensor) – The mean of the per-variable-level inverse variance of time differences.
per_variable_level_std (Tensor) – The standard deviation of the per-variable-level inverse variance of time differences.
recompute (bool, optional) – Whether to enable recomputation. Default: False.
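The index and feature parameters above describe three edge sets (grid to mesh, mesh to mesh, mesh to grid). In the example on this page, each pair of index tensors has one entry per edge and the matching edge-feature tensor has one row per edge. A minimal sketch of a pre-construction sanity check follows, assuming that convention holds in general. The helper name `check_edge_inputs` is hypothetical and not part of mindscience; it works on plain lists or 1-D NumPy arrays, before they are wrapped in `Tensor`.

```python
def check_edge_inputs(src_idx, dst_idx, edge_feats, num_src_nodes, num_dst_nodes):
    """Hypothetical helper (not a mindscience API): validate one edge set.

    Assumes src_idx and dst_idx hold one node index per edge and
    edge_feats holds one feature row per edge.
    Returns the number of edges if everything is consistent.
    """
    num_edges = len(edge_feats)
    if len(src_idx) != num_edges or len(dst_idx) != num_edges:
        raise ValueError("src_idx, dst_idx and edge_feats must have one entry per edge")
    if num_edges:
        # Every index must point at an existing node on its side of the edge.
        if min(src_idx) < 0 or max(src_idx) >= num_src_nodes:
            raise ValueError("src_idx contains an out-of-range node index")
        if min(dst_idx) < 0 or max(dst_idx) >= num_dst_nodes:
            raise ValueError("dst_idx contains an out-of-range node index")
    return num_edges


# Toy grid-to-mesh edge set: 3 grid nodes, 2 mesh nodes, 4 edges, 4 channels each.
g2m_src = [0, 1, 2, 2]
g2m_dst = [1, 0, 0, 1]
g2m_feats = [[0.1, 0.2, 0.3, 0.4] for _ in g2m_src]
num_g2m_edges = check_edge_inputs(g2m_src, g2m_dst, g2m_feats,
                                  num_src_nodes=3, num_dst_nodes=2)
```

For the grid-to-mesh set the source side is the grid and the destination side is the mesh; for the mesh-to-grid set the roles are swapped, and for mesh-to-mesh both sides are the mesh.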
- Inputs:
input (Tensor) - Tensor of shape \((batch\_size, height\_size * width\_size, feature\_size)\).
- Outputs:
output (Tensor) - Tensor of shape \((height\_size * width\_size, feature\_size)\).
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, Tensor
>>> from mindscience.models.GraphCast.graphcastnet import GraphCastNet
>>>
>>> mesh_node_num = 2562
>>> grid_node_num = 32768
>>> mesh_edge_num = 20460
>>> g2m_edge_num = 50184
>>> m2g_edge_num = 98304
>>> vm_in_channels = 3
>>> em_in_channels = 4
>>> eg2m_in_channels = 4
>>> em2g_in_channels = 4
>>> feature_num = 69
>>> g2m_src_idx = Tensor(np.random.randint(0, grid_node_num, size=[g2m_edge_num]), ms.int32)
>>> g2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[g2m_edge_num]), ms.int32)
>>> m2m_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2g_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[m2g_edge_num]), ms.int32)
>>> m2g_dst_idx = Tensor(np.random.randint(0, grid_node_num, size=[m2g_edge_num]), ms.int32)
>>> mesh_node_feats = Tensor(np.random.rand(mesh_node_num, vm_in_channels).astype(np.float32), ms.float32)
>>> mesh_edge_feats = Tensor(np.random.rand(mesh_edge_num, em_in_channels).astype(np.float32), ms.float32)
>>> g2m_edge_feats = Tensor(np.random.rand(g2m_edge_num, eg2m_in_channels).astype(np.float32), ms.float32)
>>> m2g_edge_feats = Tensor(np.random.rand(m2g_edge_num, em2g_in_channels).astype(np.float32), ms.float32)
>>> per_variable_level_mean = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> per_variable_level_std = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> grid_node_feats = Tensor(np.random.rand(grid_node_num, feature_num).astype(np.float32), ms.float32)
>>> graphcast_model = GraphCastNet(vg_in_channels=feature_num,
...                                vg_out_channels=feature_num,
...                                vm_in_channels=vm_in_channels,
...                                em_in_channels=em_in_channels,
...                                eg2m_in_channels=eg2m_in_channels,
...                                em2g_in_channels=em2g_in_channels,
...                                latent_dims=512,
...                                processing_steps=4,
...                                g2m_src_idx=g2m_src_idx,
...                                g2m_dst_idx=g2m_dst_idx,
...                                m2m_src_idx=m2m_src_idx,
...                                m2m_dst_idx=m2m_dst_idx,
...                                m2g_src_idx=m2g_src_idx,
...                                m2g_dst_idx=m2g_dst_idx,
...                                mesh_node_feats=mesh_node_feats,
...                                mesh_edge_feats=mesh_edge_feats,
...                                g2m_edge_feats=g2m_edge_feats,
...                                m2g_edge_feats=m2g_edge_feats,
...                                per_variable_level_mean=per_variable_level_mean,
...                                per_variable_level_std=per_variable_level_std)
>>> out = graphcast_model(Tensor(grid_node_feats, ms.float32))
>>> print(out.shape)
(32768, 69)
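The description calls GraphCast autoregressive: each prediction can serve as the input to the next step. The sketch below illustrates only that loop, in plain Python, and is not the mindscience training or inference pipeline. The names `rollout` and `toy_step` are hypothetical; `toy_step` stands in for a model call and assumes the output has the same layout as the input (as when `vg_in_channels` equals `vg_out_channels` in the example above).

```python
def rollout(step_fn, initial_state, num_steps):
    """Hypothetical helper: apply step_fn repeatedly, feeding each output back in."""
    forecasts = []
    state = initial_state
    for _ in range(num_steps):
        state = step_fn(state)   # one forecast step
        forecasts.append(state)  # keep every intermediate forecast
    return forecasts


def toy_step(state):
    """Stand-in for a trained model call; here it just adds 1.0 to every value."""
    return [value + 1.0 for value in state]


forecasts = rollout(toy_step, [0.0, 10.0], num_steps=3)
```

With a real network, `step_fn` would be a callable that maps the current grid-node features to the next ones; any reshaping needed between output and input is left to the caller.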