{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# NowcastNet: Physics-based Generative Model for Extreme Precipitation Nowcasting\n", "\n", "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/nowcasting/mindspore_Nowcastnet.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/nowcasting/mindspore_Nowcastnet.py) [![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindearth/docs/source_en/nowcasting/Nowcastnet.ipynb)\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Overview\n", "\n", "NowcastNet is a radar-based precipitation nowcasting model developed by Long Mingsheng's team at Tsinghua University. It forecasts precipitation for lead times of 0 to 3 hours at a spatial resolution of about 1 km. It combines a deterministic evolution network with a stochastic generative network: the evolution network yields physically plausible predictions of advective features at scales of about 20 km, and the generative network adds the convective details present in radar observations. For more information, see the [paper](https://www.nature.com/articles/s41586-023-06184-4). 
The architecture of NowcastNet is shown below.\n", "\n", "![nowcastnet](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindearth/docs/source_en/nowcasting/images/nowcastnet.png)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## NowcastNet\n", "\n", "1. Evolution network: a motion decoder learns the motion fields $v_{1:T}$ and an intensity decoder learns the intensity residuals $s_{1:T}$. The evolution operator then uses $v_{1:T}$ and $s_{1:T}$ to evolve the past radar observations $x_{-T:0}$ into the prediction $x_{1:T}^{''}$:\n", "\n", "$$\n", "x_{1:T}^{''} = \\mathrm{Evolution}(x_{-T:0})\n", "$$\n", "\n", "2. Nowcast encoder and decoder: this part uses spatially-adaptive normalization (SPADE), introduced in [Semantic Image Synthesis with Spatially-Adaptive Normalization](https://openaccess.thecvf.com/content_CVPR_2019/papers/Park_Semantic_Image_Synthesis_With_Spatially-Adaptive_Normalization_CVPR_2019_paper.pdf), to model the relation between the physics conditioning $x_{1:T}^{''}$ and the final outputs." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Technology path\n", "\n", "MindSpore Earth solves the problem in the following steps:\n", "\n", "1. Data Construction.\n", "2. Model Construction.\n", "3. Loss Function.\n", "4. Model Training.\n", "5. Model Evaluation and Visualization." 
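] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [
"The evolution operator described in the NowcastNet section can be illustrated with a minimal NumPy sketch. It assumes one common reading of the operator. Starting from the last observation, each step backward-warps the previous evolved frame with the motion field $v_t$ (bilinear sampling, clamped at the edges) and then adds the intensity residual $s_t$, i.e. $x_t^{''} = \\mathrm{warp}(x_{t-1}^{''}, v_t) + s_t$. The helper names `warp_bilinear` and `evolve`, the `(v_y, v_x)` channel order, and the use of pixel-unit displacements are illustrative assumptions. This is not the MindEarth implementation in `src.evolution`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def warp_bilinear(field, flow):\n",
"    # Backward warping: out[y, x] = field[y - v_y, x - v_x], sampled bilinearly and clamped at the edges\n",
"    h, w = field.shape\n",
"    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')\n",
"    src_y = np.clip(yy - flow[0], 0, h - 1)\n",
"    src_x = np.clip(xx - flow[1], 0, w - 1)\n",
"    y0 = np.floor(src_y).astype(int)\n",
"    x0 = np.floor(src_x).astype(int)\n",
"    y1 = np.minimum(y0 + 1, h - 1)\n",
"    x1 = np.minimum(x0 + 1, w - 1)\n",
"    wy, wx = src_y - y0, src_x - x0\n",
"    top = (1 - wx) * field[y0, x0] + wx * field[y0, x1]\n",
"    bottom = (1 - wx) * field[y1, x0] + wx * field[y1, x1]\n",
"    return (1 - wy) * top + wy * bottom\n",
"\n",
"\n",
"def evolve(last_obs, motion, residual):\n",
"    # motion: (T, 2, H, W) per-step displacements in pixels; residual: (T, H, W) intensity residuals\n",
"    frames, current = [], last_obs\n",
"    for v_t, s_t in zip(motion, residual):\n",
"        current = warp_bilinear(current, v_t) + s_t  # x''_t = warp(x''_{t-1}, v_t) + s_t\n",
"        frames.append(current)\n",
"    return np.stack(frames)\n",
"\n",
"\n",
"# Toy check: a single echo at (2, 2) drifting one pixel per step in +x for 3 steps\n",
"obs_demo = np.zeros((8, 8))\n",
"obs_demo[2, 2] = 1.0\n",
"motion_demo = np.zeros((3, 2, 8, 8))\n",
"motion_demo[:, 1] = 1.0\n",
"evo_demo = evolve(obs_demo, motion_demo, np.zeros((3, 8, 8)))\n",
"print(evo_demo.shape)  # (3, 8, 8)\n",
"print(np.argwhere(evo_demo[-1] > 0.5))  # [[2 5]]"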
] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Download the training and test datasets from [Nowcastnet/dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/nowcastnet/tiny_datasets/)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2024-01-30T12:21:07.849115100Z", "start_time": "2024-01-30T12:21:01.986652Z" }, "collapsed": false }, "outputs": [], "source": [ "import random\n", "\n", "import mindspore as ms\n", "import numpy as np\n", "from mindspore import context, nn, amp, set_seed, load_checkpoint, load_param_into_net" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2024-01-31T05:59:25.475709500Z", "start_time": "2024-01-31T05:59:25.432825200Z" }, "collapsed": false }, "outputs": [], "source": [ "from src import get_logger\n", "from src import EvolutionTrainer, GenerationTrainer, GenerateLoss, DiscriminatorLoss, EvolutionLoss\n", "from src import EvolutionPredictor, GenerationPredictor\n", "from src import RadarData, NowcastDataset\n", "from src.evolution import EvolutionNet\n", "from src.generator import GenerationNet\n", "from src.discriminator import TemporalDiscriminator\n", "from src.visual import plt_img\n", "from mindearth.utils.tools import load_yaml_config" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2024-01-30T12:21:12.488547200Z", "start_time": "2024-01-30T12:21:12.479570700Z" }, "collapsed": false }, "outputs": [], "source": [ "# Fix the random seeds for reproducibility\n", "np.random.seed(0)\n", "set_seed(0)\n", "random.seed(0)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2024-01-30T12:21:14.948562900Z", "start_time": "2024-01-30T12:21:14.931578100Z" }, "collapsed": false }, "outputs": [], "source": [ "# Load the NowcastNet configuration and run in graph mode on Ascend device 1\n", "config = load_yaml_config(\"./configs/Nowcastnet.yaml\")\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=1)" ] }, { 
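"cell_type": "markdown", "metadata": { "collapsed": false }, "source": [
"The data settings in `Nowcastnet.yaml` determine how forecast lead times map to output frames. The plain-Python sketch below reproduces the index arithmetic that the evaluation cells use (`x // time_interval - 1`). Its values are copied by hand from the configuration log printed in the Evolution section: `t_in` = 9, `t_out` = 20, `time_interval` = 10 and `key_info_timestep` = [10, 60, 120]. The helper `lead_time_to_index` is hypothetical, and the notebook itself reads these values from `config`. The sketch also assumes that the frames after the first `t_in` are the `t_out` targets."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# Values copied by hand from the logged configuration (illustrative only)\n",
"t_in, t_out, time_interval = 9, 20, 10\n",
"key_info_timestep = [10, 60, 120]  # lead times in minutes\n",
"\n",
"\n",
"def lead_time_to_index(minutes, interval):\n",
"    # The first output frame is T+interval, so it maps to index 0\n",
"    return minutes // interval - 1\n",
"\n",
"\n",
"lead_idx = [lead_time_to_index(m, time_interval) for m in key_info_timestep]\n",
"print(lead_idx)  # [0, 5, 11]\n",
"\n",
"# Assumed layout: t_in observed frames followed by t_out target frames\n",
"seq_len = t_in + t_out\n",
"target_frames = list(range(t_in, seq_len))\n",
"print(seq_len, len(target_frames))  # 29 20"
] }, { 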
"cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Data Construction\n", "\n", "Download the statistic, training and validation dataset from [dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/nowcastnet/tiny_datasets/) to `./dataset`. Modify the parameter of `root_dir` in the [Nowcastnet.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindEarth/applications/nowcasting/Nowcastnet/configs/Nowcastnet.yaml), which set the directory for dataset.\n", "\n", "The `./dataset` is hosted with the following directory structure:\n", "\n", "```markdown\n", "├── train\n", "├── valid\n", "├── test\n", "```" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Evolution" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2024-01-30T12:21:28.111872100Z", "start_time": "2024-01-30T12:21:19.226508600Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2024-01-30 20:21:45,539 - utils.py[line:55] - INFO: {'name': 'NowcastNet', 'ngf': 32, 'pool_ensemble_num': 4, 'module_name': 'generation'}\n", "2024-01-30 20:21:45,539 - utils.py[line:55] - INFO: {'name': 'us', 'root_dir': './datasets', 't_in': 9, 't_out': 20, 'h_size': 512, 'w_size': 512, 'time_interval': 10, 'num_workers': 1, 'data_sink': False, 'batch_size': 1, 'noise_scale': 32}\n", "2024-01-30 20:21:45,540 - utils.py[line:55] - INFO: {'name': 'adam', 'beta1': 0.01, 'beta2': 0.9, 'g_lr': 1.5e-05, 'd_lr': '6e-5', 'epochs': 10}\n", "2024-01-30 20:21:45,541 - utils.py[line:55] - INFO: {'name': 'adam', 'lr': 0.001, 'weight_decay': 0.1, 'gamma': 0.5, 'epochs': 5}\n", "2024-01-30 20:21:45,541 - utils.py[line:55] - INFO: {'summary_dir': './summary/', 'eval_interval': 2, 'save_checkpoint_epochs': 2, 'keep_checkpoint_max': 4, 'key_info_timestep': [10, 60, 120], 'generate_ckpt_path': '', 'evolution_ckpt_path': '', 'visual': True, 'csin_threshold': 16}\n", 
"2024-01-30 20:21:45,542 - utils.py[line:55] - INFO: {'distribute': False, 'mixed_precision': True, 'amp_level': 'O2', 'load_ckpt': False}\n" ] }, { "data": { "text/plain": [ "EvolutionNet<\n", " (evo_net): EvolutionNetwork<\n", " (inc): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " (down1): Down<\n", " (maxpool_conv): SequentialCell<\n", " (0): MaxPool2d\n", " (1): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " >\n", " (down2): Down<\n", " (maxpool_conv): SequentialCell<\n", " (0): MaxPool2d\n", " (1): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " >\n", " (down3): Down<\n", " (maxpool_conv): 
SequentialCell<\n", " (0): MaxPool2d\n", " (1): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " >\n", " (down4): Down<\n", " (maxpool_conv): SequentialCell<\n", " (0): MaxPool2d\n", " (1): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " >\n", " (up1): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up2): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): 
BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up3): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up4): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (outc): OutConv<\n", " (conv): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (up1_v): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, 
format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up2_v): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up3_v): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (up4_v): Up<\n", " (up): Upsample<>\n", " (conv): DoubleConv<\n", " (single_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " (double_conv): SequentialCell<\n", " (0): BatchNorm2d\n", " (1): ReLU<>\n", " (2): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " (3): BatchNorm2d\n", " (4): ReLU<>\n", " (5): SpectralNormal<\n", " (parametrizations): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >\n", " >\n", " (outc_v): OutConv<\n", " (conv): Conv2d, bias_init=, format=NCHW>\n", " >\n", " >\n", " >" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logger = get_logger(config)\n", "config[\"model\"][\"module_name\"] = 
'evolution'\n", "config[\"data\"][\"batch_size\"] = 4\n", "config[\"summary\"][\"eval_interval\"] = 1\n", "config[\"summary\"][\"visual\"] = False\n", "train_params = config.get(\"train\")\n", "summary_params = config.get(\"summary\")\n", "evo_model = EvolutionNet(config)\n", "evo_model.set_train()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Model Training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2024-01-31T03:24:30.164901700Z", "start_time": "2024-01-31T01:01:39.614516800Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 398, loss is 6.6358147\n", "epoch: 1 step: 399, loss is 11.47153\n", "epoch: 1 step: 400, loss is 7.2303286\n", "Train epoch time: 1794325.439 ms, per step time: 4485.814 ms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 09:32:00,659 - forecast.py[line:191] - INFO: ================================Start Evaluation================================\n", "2024-01-31 09:32:00,661 - forecast.py[line:192] - INFO: The length of data is: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 09:32:04,107 - forecast.py[line:179] - INFO: CSI Neighborhood threshold 16 T+10 min: 0.4054458796087876 T+60 min: 0.16474475307855177 T+120 min: 0.09442198339292594\n", "2024-01-31 09:32:04,181 - forecast.py[line:211] - INFO: ================================End Evaluation================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 2 step: 398, loss is 7.7851076\n", "epoch: 2 step: 399, loss is 6.364196\n", "epoch: 2 step: 400, loss is 11.064963\n", "Train epoch time: 1688639.108 ms, per step time: 4221.598 ms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:00:12,836 - forecast.py[line:191] - INFO: 
================================Start Evaluation================================\n", "2024-01-31 10:00:12,838 - forecast.py[line:192] - INFO: The length of data is: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\\\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:00:16,213 - forecast.py[line:179] - INFO: CSI Neighborhood threshold 16 T+10 min: 0.41709989523909563 T+60 min: 0.16894114336218546 T+120 min: 0.10467792846088278\n", "2024-01-31 10:00:16,291 - forecast.py[line:211] - INFO: ================================End Evaluation================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 3 step: 398, loss is 5.818244\n", "epoch: 3 step: 399, loss is 6.7248154\n", "epoch: 3 step: 400, loss is 10.864609\n", "Train epoch time: 1690693.724 ms, per step time: 4226.734 ms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:28:27,000 - forecast.py[line:191] - INFO: ================================Start Evaluation================================\n", "2024-01-31 10:28:27,002 - forecast.py[line:192] - INFO: The length of data is: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "|\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:28:30,512 - forecast.py[line:179] - INFO: CSI Neighborhood threshold 16 T+10 min: 0.3993438740760541 T+60 min: 0.16725303802177002 T+120 min: 0.0959103616970071\n", "2024-01-31 10:28:30,583 - forecast.py[line:211] - INFO: ================================End Evaluation================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 4 step: 398, loss is 7.879731\n", "epoch: 4 step: 399, loss is 9.270348\n", "epoch: 4 step: 400, loss is 10.59611\n", "Train epoch time: 1687615.214 ms, per step time: 4219.038 ms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:56:38,214 - forecast.py[line:191] - INFO: 
================================Start Evaluation================================\n", "2024-01-31 10:56:38,216 - forecast.py[line:192] - INFO: The length of data is: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 10:56:41,586 - forecast.py[line:179] - INFO: CSI Neighborhood threshold 16 T+10 min: 0.4119070738020246 T+60 min: 0.16328413060990918 T+120 min: 0.10628156308461514\n", "2024-01-31 10:56:41,656 - forecast.py[line:211] - INFO: ================================End Evaluation================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 5 step: 398, loss is 14.705731\n", "epoch: 5 step: 399, loss is 10.468576\n", "epoch: 5 step: 400, loss is 6.882686\n", "Train epoch time: 1691109.624 ms, per step time: 4227.774 ms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 11:24:52,789 - forecast.py[line:191] - INFO: ================================Start Evaluation================================\n", "2024-01-31 11:24:52,790 - forecast.py[line:192] - INFO: The length of data is: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-01-31 11:24:56,256 - forecast.py[line:179] - INFO: CSI Neighborhood threshold 16 T+10 min: 0.4118216017414204 T+60 min: 0.15679251293172677 T+120 min: 0.10144497020636714\n", "2024-01-31 11:24:56,322 - forecast.py[line:211] - INFO: ================================End Evaluation================================\n" ] } ], "source": [ "loss_scale = ms.train.loss_scale_manager.FixedLossScaleManager(loss_scale=2048)\n", "evo_loss_fn = EvolutionLoss(evo_model, config)\n", "trainer = EvolutionTrainer(config, evo_model, evo_loss_fn, logger, loss_scale)\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Evaluation and Visualization\n", "\n", "After 
training, we load the saved checkpoint and run inference. The predictions, the ground truth, and the prediction error are visualized below.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2024-01-31T03:27:47.291549200Z", "start_time": "2024-01-31T03:27:46.401222Z" }, "collapsed": false }, "outputs": [], "source": [ "config[\"data\"][\"batch_size\"] = 1\n", "config[\"summary\"][\"visual\"] = True\n", "# Load the trained evolution checkpoint and switch the network to inference mode\n", "params = load_checkpoint('./summary/ckpt/evolution-3_200.ckpt')\n", "evo_model.set_train(False)\n", "load_param_into_net(evo_model, params)\n", "evo_inference = EvolutionPredictor(config, evo_model, logger)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2024-01-31T06:11:02.019387Z", "start_time": "2024-01-31T06:11:00.391020800Z" }, "collapsed": false }, "outputs": [], "source": [ "data_params = config.get(\"data\")\n", "test_dataset_generator = RadarData(data_params, run_mode='test', module_name=\"evolution\")\n", "test_dataset = NowcastDataset(test_dataset_generator,\n", "                              module_name=\"evolution\",\n", "                              distribute=train_params.get('distribute', False),\n", "                              num_workers=data_params.get('num_workers', 1),\n", "                              shuffle=False)\n", "test_dataset = test_dataset.create_dataset(data_params.get('batch_size', 1))\n", "# Use the 6th test sample for visualization\n", "for step, d in enumerate(test_dataset.create_dict_iterator(), start=1):\n", "    if step == 6:\n", "        data = d\n", "        break\n", "inputs = data['inputs']\n", "pred = evo_inference.forecast(inputs)\n", "# Frames after the first t_in observations are the ground truth\n", "labels = inputs[:, data_params.get(\"t_in\"):]\n", "# Convert lead times in minutes to 0-based output-frame indices (T+10 min maps to 0)\n", "plt_idx = [x // data_params.get(\"time_interval\") - 1 for x in summary_params.get(\"key_info_timestep\", [10, 60, 120])]\n", "plt_img(field=pred[0].asnumpy(), label=labels[0].asnumpy(), idx=plt_idx, fig_name=\"./evolution_example.png\")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Generation" ] }, { "cell_type": "markdown", 
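"metadata": { "collapsed": false }, "source": [
"Before training the generator, it may help to see what spatially-adaptive normalization (SPADE), mentioned in the NowcastNet section, does. The NumPy sketch below follows the formula from the SPADE paper. Features are normalized per channel with batch statistics, then scaled by a $\\gamma$ map and shifted by a $\\beta$ map, both computed from the conditioning input. As a result, the modulation varies over space. In NowcastNet the conditioning is the evolution output $x_{1:T}^{''}$. The function `spade` and the 1x1 projection weights `w_gamma` and `w_beta` are hypothetical stand-ins for the learned convolutions. This is not the code in `src.generator`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def spade(feat, cond, w_gamma, w_beta, eps=1e-5):\n",
"    # feat: (N, C, H, W) features; cond: (N, K, H, W) conditioning map at the same resolution\n",
"    # w_gamma, w_beta: (C, K) weights of hypothetical 1x1 convolutions\n",
"    mean = feat.mean(axis=(0, 2, 3), keepdims=True)\n",
"    var = feat.var(axis=(0, 2, 3), keepdims=True)\n",
"    normed = (feat - mean) / np.sqrt(var + eps)  # parameter-free batch normalization\n",
"    gamma = np.einsum('ck,nkhw->nchw', w_gamma, cond)  # per-pixel scale\n",
"    beta = np.einsum('ck,nkhw->nchw', w_beta, cond)  # per-pixel shift\n",
"    return gamma * normed + beta\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"feat = rng.normal(3.0, 2.0, size=(2, 4, 8, 8))\n",
"\n",
"# Constant conditioning with gamma = 1 and beta = 0 reduces SPADE to batch normalization\n",
"out_bn = spade(feat, np.ones((2, 1, 8, 8)), np.ones((4, 1)), np.zeros((4, 1)))\n",
"print(np.allclose(out_bn.mean(axis=(0, 2, 3)), 0.0))  # True\n",
"\n",
"# A spatial mask as conditioning shifts only the masked region\n",
"mask = np.zeros((2, 1, 8, 8))\n",
"mask[..., :4] = 1.0\n",
"out_mask = spade(feat, mask, np.zeros((4, 1)), np.full((4, 1), 5.0))\n",
"print(out_mask[0, 0, 0, :])  # [5. 5. 5. 5. 0. 0. 0. 0.]"
] }, { "cell_type": "markdown", 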
"metadata": { "collapsed": false }, "source": [ "### Model Training" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "config[\"model\"][\"module_name\"] = 'generation'\n", "config[\"data\"][\"batch_size\"] = 1\n", "config[\"summary\"][\"visual\"] = False\n", "config[\"summary\"][\"save_checkpoint_epochs\"] = 1\n", "train_params = config.get(\"train\")\n", "summary_params = config.get(\"summary\")\n", "g_model = GenerationNet(config)\n", "d_model = TemporalDiscriminator(data_params.get(\"t_in\", 9) + data_params.get(\"t_out\", 20))\n", "g_model.set_train()\n", "d_model.set_train()\n", "g_model = amp.auto_mixed_precision(g_model, amp_level=train_params.get(\"amp_level\", 'O2'))\n", "d_model = amp.auto_mixed_precision(d_model, amp_level=train_params.get(\"amp_level\", 'O2'))\n", "loss_scale = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000)\n", "g_loss_fn = GenerateLoss(g_model, d_model)\n", "d_loss_fn = DiscriminatorLoss(g_model, d_model)\n", "trainer = GenerationTrainer(config, g_model, d_model, g_loss_fn, d_loss_fn, logger, loss_scale)\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Evaluation and Visualization\n", "\n", "After training, we use the checkpoint for inference. The visualization of predictions, ground truth and their error is shown below." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "config[\"summary\"][\"visual\"] = True\n", "config[\"train\"][\"load_ckpt\"] = True\n", "gen_inference = GenerationPredictor(config, g_model, logger)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "data_params = config.get(\"data\")\n", "model_params = config.get(\"model\")\n", "test_dataset_generator = RadarData(data_params, run_mode='test', module_name=\"generation\")\n", "test_dataset = NowcastDataset(test_dataset_generator,\n", " module_name=\"generation\",\n", " distribute=train_params.get('distribute', False),\n", " num_workers=data_params.get('num_workers', 1),\n", " shuffle=False)\n", "test_dataset = test_dataset.create_dataset(data_params.get('batch_size', 1))\n", "data = next(test_dataset.create_dict_iterator())\n", "inp, evo_result, labels = data.get(\"inputs\"), data.get(\"evo\"), data.get(\"labels\")\n", "noise_scale = data_params.get(\"noise_scale\", 32)\n", "threshold = summary_params.get(\"csin_threshold\", 16)\n", "batch_size = data_params.get(\"batch_size\", 1)\n", "w_size = data_params.get(\"w_size\", 512)\n", "h_size = data_params.get(\"h_size\", 512)\n", "ngf = model_params.get(\"ngf\", 32)\n", "noise = ms.tensor(ms.numpy.randn((batch_size, ngf, h_size // noise_scale, w_size // noise_scale)), inp.dtype)\n", "pred = gen_inference.generator(inp, evo_result, noise)\n", "plt_img(field=pred[0].asnumpy(), label=labels[0].asnumpy(), idx=plt_idx, fig_name=\"./generation_example.png\", evo=evo_result[0].asnumpy() * 128, plot_evo=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, 
"nbformat": 4, "nbformat_minor": 0 }