{ "cells": [ { "cell_type": "markdown", "id": "bb7af2dd-61ff-4033-bc04-3a1ec598f5c4", "metadata": {}, "source": [ "# PreDiff: Precipitation Nowcasting Based on Latent Diffusion Models\n", "\n", "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_notebook_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindearth/en/nowcasting/mindspore_prediffnet.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_download_code_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindearth/en/nowcasting/mindspore_prediffnet.py) [![ViewSource](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindearth/docs/source_en/nowcasting/prediffnet.ipynb)" ] }, { "cell_type": "markdown", "id": "40bf4c9f", "metadata": {}, "source": [ "## Overview\n", "\n", "Traditional weather forecasting techniques rely on complex physical models.These models are not only computationally expensive but also require profound professional knowledge as support. However, in the past decade, with the explosive growth of Earth's spatiotemporal observation data, deep learning techniques have opened up new avenues for constructing data-driven prediction models. Although these models have demonstrated great potential in various Earth system prediction tasks, they still have deficiencies in managing uncertainties and integrating domain-specific prior knowledge, often leading to ambiguous or physically implausible prediction results.\n", "\n", "To overcome these challenges, Gao Zhihan from the Hong Kong University of Science and Technology has implemented the **PreDiff** model, which is specifically designed to achieve probabilistic spatiotemporal prediction. This process integrates a conditional latent diffusion model with an explicit knowledge alignment mechanism, aiming to generate prediction results that not only conform to the physical constraints of the specific domain but also accurately capture spatiotemporal variations. Through this approach, we expect to significantly improve the accuracy and reliability of Earth system predictions. The model framework diagram is shown below (the image is from the paper [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS)):\n", "\n", "![prediff](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/docs/mindearth/docs/source_en/nowcasting/images/prediffnet.jpg)\n", "\n", "During the training process, the variational auto-encoder extracts key information from the data into the latent space. Then, a time step is randomly selected to generate the corresponding noise for that time step, and the data is then augmented with this noise. Subsequently, the noisy data is fed into the Earthformer-UNet for denoising. The Earthformer-UNet employs the UNet architecture and Cuboid Attention, and removes the Cross Attention structure that connects the Encoder and Decoder in the original Earthformer. Finally, the denoised data is obtained by restoring the results through the variational auto-decoder. The diffusion model learns the data distribution by reversing the pre-defined noise-adding process that corrupts the original data." 
] }, { "cell_type": "markdown", "id": "d3ca5c0b-4fe8-4764-b046-b154df49cc9b", "metadata": {}, "source": [ "## Technology Path\n", "\n", "The specific process of MindSpore Earth in solving this problem is as follows:\n", "\n", "1. Create a dataset.\n", "2. Construct the model.\n", "3. Define the loss function.\n", "4. Train the model.\n", "5. Evaluate and visualize the model.\n", "\n", "The dataset can be downloaded from [PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip) and saved." ] }, { "cell_type": "code", "execution_count": 1, "id": "56631948-30aa-4360-84d1-e42bff3ab87c", "metadata": {}, "outputs": [], "source": [ "import time\n", "import os\n", "import random\n", "import json\n", "from typing import Sequence, Union\n", "import numpy as np\n", "from einops import rearrange\n", "\n", "import mindspore as ms\n", "from mindspore import set_seed, context, ops, nn, mint\n", "from mindspore.experimental import optim\n", "from mindspore.train.serialization import save_checkpoint" ] }, { "cell_type": "markdown", "id": "415eb386-b9ef-42af-9ab6-f98c9d8151da", "metadata": {}, "source": [ "The following src can be downloaded from [PreDiff/src](https://gitee.com/mindspore/mindscience/tree/r0.7/MindEarth/applications/nowcasting/PreDiff/src)." ] }, { "cell_type": "code", "execution_count": 2, "id": "3ab31256-d802-4c68-9066-6c2cc9e73dcd", "metadata": {}, "outputs": [], "source": [ "from mindearth.utils import load_yaml_config\n", "\n", "from src import (\n", " prepare_output_directory,\n", " configure_logging_system,\n", " prepare_dataset,\n", " init_model,\n", " PreDiffModule,\n", " DiffusionTrainer,\n", " DiffusionInferrence\n", ")\n", "from src.sevir_dataset import SEVIRDataset\n", "from src.visual import vis_sevir_seq\n", "from src.utils import warmup_lambda" ] }, { "cell_type": "code", "execution_count": 3, "id": "e2576e2e-b98f-4791-8634-d68e55531ede", "metadata": {}, "outputs": [], "source": [ "set_seed(0)\n", "np.random.seed(0)\n", "random.seed(0)" ] }, { "cell_type": "markdown", "id": "7272ed00-61ef-439c-a420-b3bcefc13965", "metadata": {}, "source": [ "The parameters of the model, data, optimizer, etc. can be configured in the [configuration file](https://gitee.com/mindspore/mindscience/tree/r0.7/MindEarth/applications/nowcasting/PreDiff/configs)." ] }, { "cell_type": "code", "execution_count": 4, "id": "6c2a2194-7a48-4412-b4d8-9bcae2d5c280", "metadata": {}, "outputs": [], "source": [ "config = load_yaml_config(\"./configs/diffusion.yaml\")\n", "context.set_context(mode=ms.PYNATIVE_MODE)\n", "ms.set_device(device_target=\"Ascend\", device_id=0)" ] }, { "cell_type": "markdown", "id": "9e4b94cd-f806-4e90-bf67-030e66253274", "metadata": {}, "source": [ "## Model Construction\n", "\n", "The model initialization mainly includes the initialization of the variational autoencoder and the earthformer." 
] }, { "cell_type": "code", "execution_count": 5, "id": "36dabe74-de56-40c2-9db6-370ad2d7a0fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:32:11,466 - utils.py[line:820] - INFO: Process ID: 2231351\n", "2025-04-07 10:32:11,467 - utils.py[line:821] - INFO: {'summary_dir': './summary/prediff/single_device0', 'eval_interval': 10, 'save_ckpt_epochs': 1, 'keep_ckpt_max': 100, 'ckpt_path': '', 'load_ckpt': False}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "NoisyCuboidTransformerEncoder param_not_load: []\n", "Cleared previous output directory: ./summary/prediff/single_device0\n" ] } ], "source": [ "main_module = PreDiffModule(oc_file=\"./configs/diffusion.yaml\")\n", "main_module = init_model(module=main_module, config=config, mode=\"train\")\n", "output_dir = prepare_output_directory(config, \"0\")\n", "logger = configure_logging_system(output_dir, config)" ] }, { "cell_type": "markdown", "id": "ef546194-7eef-46a6-8ffa-504d4b58fa25", "metadata": {}, "source": [ "## Creating a Dataset\n", "\n", "Download the [sevir-lr](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip) dataset to the `./dataset` directory." ] }, { "cell_type": "code", "execution_count": 6, "id": "d4b45562-49b1-4ab3-ab52-bad420c30236", "metadata": {}, "outputs": [], "source": [ "dm, total_num_steps = prepare_dataset(config, PreDiffModule)" ] }, { "cell_type": "markdown", "id": "0fc4c0ce-9870-43f5-931b-30c8b5705122", "metadata": {}, "source": [ "## Loss Function\n", "\n", "During the training of PreDiff, the mean squared error (mse) is used as the loss calculation. Gradient clipping is adopted, and the process is encapsulated in the DiffusionTrainer." ] }, { "cell_type": "code", "execution_count": 7, "id": "4e8ff7ff-3641-4a1a-b412-6587c5b09562", "metadata": {}, "outputs": [], "source": [ "class DiffusionTrainer(nn.Cell):\n", " \"\"\"\n", " Class managing the training pipeline for diffusion models. 
Handles dataset processing,\n", " optimizer configuration, gradient clipping, checkpoint saving, and logging.\n", " \"\"\"\n", " def __init__(self, main_module, dm, logger, config):\n", " \"\"\"\n", " Initialize trainer with model, data module, logger, and configuration.\n", " Args:\n", " main_module: Main diffusion model to be trained\n", " dm: Data module providing training dataset\n", " logger: Logging utility for training progress\n", " config: Configuration dictionary containing hyperparameters\n", " \"\"\"\n", " super().__init__()\n", " self.main_module = main_module\n", " self.traindataset = dm.sevir_train\n", " self.logger = logger\n", " self.datasetprocessing = SEVIRDataset(\n", " data_types=[\"vil\"],\n", " layout=\"NHWT\",\n", " rescale_method=config.get(\"rescale_method\", \"01\"),\n", " )\n", " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", " self.fs = config[\"eval\"].get(\"fs\", 20)\n", " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", " self.current_epoch = 0\n", " self.learn_logvar = (\n", " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", False)\n", " )\n", " self.logvar = main_module.logvar\n", " self.maeloss = nn.MAELoss()\n", " self.optim_config = config[\"optim\"]\n", " self.clip_norm = config.get(\"clip_norm\", 2.0)\n", " self.ckpt_dir = os.path.join(self.example_save_dir, \"ckpt\")\n", " self.keep_ckpt_max = config[\"summary\"].get(\"keep_ckpt_max\", 100)\n", " self.ckpt_history = []\n", " self.grad_clip_fn = ops.clip_by_global_norm\n", " self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), learning_rate=0.00001)\n", " os.makedirs(self.ckpt_dir, exist_ok=True)\n", "\n", " def train(self, total_steps: int):\n", " \"\"\"Execute complete training pipeline.\"\"\"\n", " self.main_module.main_model.set_train(True)\n", " self.logger.info(\"Initializing training process...\")\n", " loss_processor = Trainonestepforward(self.main_module)\n", " grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params())\n", " for epoch in range(self.optim_config[\"max_epochs\"]):\n", " epoch_loss = 0.0\n", " epoch_start = time.time()\n", "\n", " iterator = self.traindataset.create_dict_iterator()\n", " assert iterator, \"dataset is empty\"\n", " batch_idx = 0\n", " for batch_idx, batch in enumerate(iterator):\n", " processed_data = self.datasetprocessing.process_data(batch[\"vil\"])\n", " loss_value, gradients = grad_func(processed_data)\n", " clipped_grads = self.grad_clip_fn(gradients, self.clip_norm)\n", " self.optimizer(clipped_grads)\n", " epoch_loss += loss_value.asnumpy()\n", " self.logger.info(\n", " f\"epoch: {epoch} step: {batch_idx}, loss: {loss_value}\"\n", " )\n", " self._save_ckpt(epoch)\n", " epoch_time = time.time() - epoch_start\n", " self.logger.info(\n", " f\"Epoch {epoch} completed in {epoch_time:.2f}s | \"\n", " f\"Avg Loss: {epoch_loss/(batch_idx+1):.4f}\"\n", " )\n", "\n", " def _get_optimizer(self, total_steps: int):\n", " \"\"\"Configure optimization components\"\"\"\n", " trainable_params = list(self.main_module.main_model.trainable_params())\n", " if self.learn_logvar:\n", " self.logger.info(\"Including log variance parameters\")\n", " trainable_params.append(self.logvar)\n", " optimizer = optim.AdamW(\n", " trainable_params,\n", " lr=self.optim_config[\"lr\"],\n", " betas=tuple(self.optim_config[\"betas\"]),\n", " )\n", " warmup_steps = 
int(self.optim_config[\"warmup_percentage\"] * total_steps)\n", " scheduler = self._create_lr_scheduler(optimizer, total_steps, warmup_steps)\n", "\n", " return optimizer, scheduler\n", "\n", " def _create_lr_scheduler(self, optimizer, total_steps: int, warmup_steps: int):\n", " \"\"\"Build learning rate scheduler\"\"\"\n", " warmup_scheduler = optim.lr_scheduler.LambdaLR(\n", " optimizer,\n", " lr_lambda=warmup_lambda(\n", " warmup_steps=warmup_steps,\n", " min_lr_ratio=self.optim_config[\"warmup_min_lr_ratio\"],\n", " ),\n", " )\n", "\n", " cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(\n", " optimizer,\n", " T_max=total_steps - warmup_steps,\n", " eta_min=self.optim_config[\"min_lr_ratio\"] * self.optim_config[\"lr\"],\n", " )\n", "\n", " return optim.lr_scheduler.SequentialLR(\n", " optimizer,\n", " schedulers=[warmup_scheduler, cosine_scheduler],\n", " milestones=[warmup_steps],\n", " )\n", "\n", " def _save_ckpt(self, epoch: int):\n", " \"\"\"Save model ckpt with rotation policy\"\"\"\n", " ckpt_file = f\"diffusion_epoch{epoch}.ckpt\"\n", " ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)\n", "\n", " save_checkpoint(self.main_module.main_model, ckpt_path)\n", " self.ckpt_history.append(ckpt_path)\n", "\n", " if len(self.ckpt_history) > self.keep_ckpt_max:\n", " removed_ckpt = self.ckpt_history.pop(0)\n", " os.remove(removed_ckpt)\n", " self.logger.info(f\"Removed outdated ckpt: {removed_ckpt}\")\n", "\n", "\n", "class Trainonestepforward(nn.Cell):\n", " \"\"\"A neural network cell that performs one training step forward pass for a diffusion model.\n", " This class encapsulates the forward pass computation for training a diffusion model,\n", " handling the input processing, latent space encoding, conditioning, and loss calculation.\n", " Args:\n", " model (nn.Cell): The main diffusion model containing the necessary submodules\n", " for encoding, conditioning, and loss computation.\n", " \"\"\"\n", "\n", " def __init__(self, model):\n", " super().__init__()\n", " self.main_module = model\n", "\n", " def construct(self, inputs):\n", " \"\"\"Perform one forward training step and compute the loss.\"\"\"\n", " x, condition = self.main_module.get_input(inputs)\n", " x = x.transpose(0, 1, 4, 2, 3)\n", " n, t_, c_, h, w = x.shape\n", " x = x.reshape(n * t_, c_, h, w)\n", " z = self.main_module.encode_first_stage(x)\n", " _, c_z, h_z, w_z = z.shape\n", " z = z.reshape(n, -1, c_z, h_z, w_z)\n", " z = z.transpose(0, 1, 3, 4, 2)\n", " t = ops.randint(0, self.main_module.num_timesteps, (n,)).long()\n", " zc = self.main_module.cond_stage_forward(condition)\n", " loss = self.main_module.p_losses(z, zc, t, noise=None)\n", " return loss" ] }, { "cell_type": "markdown", "id": "da1c23d7-fbff-4492-af19-79f30bcc0185", "metadata": {}, "source": [ "## Model Training\n", "\n", "In this tutorial, we use the DiffusionTrainer to train the model." ] }, { "cell_type": "code", "execution_count": 8, "id": "024a0222-c7c7-4cd0-9b6f-7350b92619af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:32:36,351 - 4106154625.py[line:46] - INFO: Initializing training process...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "........." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:09,378 - 4106154625.py[line:64] - INFO: epoch: 0 step: 0, loss: 1.0008465\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." 
] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:16,871 - 4106154625.py[line:64] - INFO: epoch: 0 step: 1, loss: 1.0023363\n", "2025-04-07 10:34:18,724 - 4106154625.py[line:64] - INFO: epoch: 0 step: 2, loss: 1.0009086\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:20,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 3, loss: 0.99787366\n", "2025-04-07 10:34:22,280 - 4106154625.py[line:64] - INFO: epoch: 0 step: 4, loss: 0.9979043\n", "2025-04-07 10:34:24,072 - 4106154625.py[line:64] - INFO: epoch: 0 step: 5, loss: 0.99897844\n", "2025-04-07 10:34:25,864 - 4106154625.py[line:64] - INFO: epoch: 0 step: 6, loss: 1.0021904\n", "2025-04-07 10:34:27,709 - 4106154625.py[line:64] - INFO: epoch: 0 step: 7, loss: 0.9984627\n", "2025-04-07 10:34:29,578 - 4106154625.py[line:64] - INFO: epoch: 0 step: 8, loss: 0.9952746\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:31,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 9, loss: 1.0003254\n", "2025-04-07 10:34:33,402 - 4106154625.py[line:64] - INFO: epoch: 0 step: 10, loss: 1.0020428\n", "2025-04-07 10:34:35,218 - 4106154625.py[line:64] - INFO: epoch: 0 step: 11, loss: 0.99563503\n", "2025-04-07 10:34:37,149 - 4106154625.py[line:64] - INFO: epoch: 0 step: 12, loss: 0.99336195\n", "2025-04-07 10:34:38,949 - 4106154625.py[line:64] - INFO: epoch: 0 step: 13, loss: 1.0023757\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "......" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 13:39:55,859 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1247, loss: 0.021378823\n", "2025-04-07 13:39:57,754 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1248, loss: 0.01565772\n", "2025-04-07 13:39:59,606 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1249, loss: 0.012067624\n", "2025-04-07 13:40:01,396 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1250, loss: 0.017700804\n", "2025-04-07 13:40:03,181 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1251, loss: 0.06254268\n", "2025-04-07 13:40:04,945 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1252, loss: 0.013293369\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 13:40:06,770 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1253, loss: 0.026906993\n", "2025-04-07 13:40:08,644 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1254, loss: 0.18210539\n", "2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894\n", "2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517\n" ] } ], "source": [ "trainer = DiffusionTrainer(\n", " main_module=main_module, dm=dm, logger=logger, config=config\n", ")\n", "trainer.train(total_steps=total_num_steps)" ] }, { "cell_type": "markdown", "id": "5be19a94-d51c-487f-bfde-8a9d9dc57d79", "metadata": {}, "source": [ "## Model Evaluation and Visualization\n", "\n", "After the training is completed, we use the 5th ckpt for inference. The following shows the error between the predicted values and the actual values and various indicators." 
] }, { "cell_type": "code", "execution_count": 14, "id": "e0654f48-d02d-4258-b86c-61da318cf10e", "metadata": {}, "outputs": [], "source": [ "def get_alignment_kwargs_avg_x(target_seq):\n", " \"\"\"Generate alignment parameters for guided sampling\"\"\"\n", " batch_size = target_seq.shape[0]\n", " avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True)\n", " return {\"avg_x_gt\": avg_intensity * 2.0}\n", "\n", "\n", "class DiffusionInferrence(nn.Cell):\n", " \"\"\"\n", " Class managing model inference and evaluation processes. Handles loading checkpoints,\n", " generating predictions, calculating evaluation metrics, and saving visualization results.\n", " \"\"\"\n", " def __init__(self, main_module, dm, logger, config):\n", " \"\"\"\n", " Initialize inference manager with model, data module, logger, and configuration.\n", " Args:\n", " main_module: Main diffusion model for inference\n", " dm: Data module providing test dataset\n", " logger: Logging utility for evaluation progress\n", " config: Configuration dictionary containing evaluation parameters\n", " \"\"\"\n", " super().__init__()\n", " self.num_samples = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", " self.eval_example_only = config[\"eval\"].get(\"eval_example_only\", True)\n", " self.alignment_type = (\n", " config.get(\"model\", {}).get(\"align\", {}).get(\"alignment_type\", \"avg_x\")\n", " )\n", " self.use_alignment = self.alignment_type is not None\n", " self.eval_aligned = config[\"eval\"].get(\"eval_aligned\", True)\n", " self.eval_unaligned = config[\"eval\"].get(\"eval_unaligned\", True)\n", " self.num_samples_per_context = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", " self.logging_prefix = config[\"logging\"].get(\"logging_prefix\", \"PreDiff\")\n", " self.test_example_data_idx_list = [48]\n", " self.main_module = main_module\n", " self.testdataset = dm.sevir_test\n", " self.logger = logger\n", " self.datasetprocessing = SEVIRDataset(\n", " data_types=[\"vil\"],\n", " layout=\"NHWT\",\n", " rescale_method=config.get(\"rescale_method\", \"01\"),\n", " )\n", " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", "\n", " self.fs = config[\"eval\"].get(\"fs\", 20)\n", " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", "\n", " self.current_epoch = 0\n", "\n", " self.learn_logvar = (\n", " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", False)\n", " )\n", " self.logvar = main_module.logvar\n", " self.maeloss = nn.MAELoss()\n", " self.test_metrics = {\n", " \"step\": 0,\n", " \"mse\": 0.0,\n", " \"mae\": 0.0,\n", " \"ssim\": 0.0,\n", " \"mse_kc\": 0.0,\n", " \"mae_kc\": 0.0,\n", " }\n", "\n", " def test(self):\n", " \"\"\"Execute complete evaluation pipeline.\"\"\"\n", " self.logger.info(\"============== Start Test ==============\")\n", " self.start_time = time.time()\n", " for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()):\n", " self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics)\n", "\n", " self._finalize_test(self.test_metrics)\n", "\n", " def _test_onestep(self, item, batch_idx, metrics):\n", " \"\"\"Process one test batch and update evaluation metrics.\"\"\"\n", " data_idx = int(batch_idx * 2)\n", " if not self._should_test_onestep(data_idx):\n", " return metrics\n", " data = item.get(\"vil\")\n", " data = self.datasetprocessing.process_data(data)\n", " target_seq, cond, 
context_seq = self._get_model_inputs(data)\n", " aligned_preds, unaligned_preds = self._generate_predictions(\n", " cond, target_seq\n", " )\n", " metrics = self._update_metrics(\n", " aligned_preds, unaligned_preds, target_seq, metrics\n", " )\n", " self._plt_pred(\n", " data_idx,\n", " context_seq,\n", " target_seq,\n", " aligned_preds,\n", " unaligned_preds,\n", " metrics[\"step\"],\n", " )\n", "\n", " metrics[\"step\"] += 1\n", " return metrics\n", "\n", " def _should_test_onestep(self, data_idx):\n", " \"\"\"Determine if evaluation should be performed on current data index.\"\"\"\n", " return (not self.eval_example_only) or (\n", " data_idx in self.test_example_data_idx_list\n", " )\n", "\n", " def _get_model_inputs(self, data):\n", " \"\"\"Extract and prepare model inputs from raw data.\"\"\"\n", " target_seq, cond, context_seq = self.main_module.get_input(\n", " data, return_verbose=True\n", " )\n", " return target_seq, cond, context_seq\n", "\n", " def _generate_predictions(self, cond, target_seq):\n", " \"\"\"Generate both aligned and unaligned predictions from the model.\"\"\"\n", " aligned_preds = []\n", " unaligned_preds = []\n", "\n", " for _ in range(self.num_samples_per_context):\n", " if self.use_alignment and self.eval_aligned:\n", " aligned_pred = self._sample_with_alignment(\n", " cond, target_seq\n", " )\n", " aligned_preds.append(aligned_pred)\n", "\n", " if self.eval_unaligned:\n", " unaligned_pred = self._sample_without_alignment(cond)\n", " unaligned_preds.append(unaligned_pred)\n", "\n", " return aligned_preds, unaligned_preds\n", "\n", " def _sample_with_alignment(self, cond, target_seq):\n", " \"\"\"Generate predictions using alignment mechanism.\"\"\"\n", " alignment_kwargs = get_alignment_kwargs_avg_x(target_seq)\n", " pred_seq = self.main_module.sample(\n", " cond=cond,\n", " batch_size=cond[\"y\"].shape[0],\n", " return_intermediates=False,\n", " use_alignment=True,\n", " alignment_kwargs=alignment_kwargs,\n", " verbose=False,\n", " )\n", " if pred_seq.dtype != ms.float32:\n", " pred_seq = pred_seq.float()\n", " return pred_seq\n", "\n", " def _sample_without_alignment(self, cond):\n", " \"\"\"Generate predictions without alignment.\"\"\"\n", " pred_seq = self.main_module.sample(\n", " cond=cond,\n", " batch_size=cond[\"y\"].shape[0],\n", " return_intermediates=False,\n", " verbose=False,\n", " )\n", " if pred_seq.dtype != ms.float32:\n", " pred_seq = pred_seq.float()\n", " return pred_seq\n", "\n", " def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics):\n", " \"\"\"Update evaluation metrics with new predictions.\"\"\"\n", " for pred in aligned_preds:\n", " metrics[\"mse_kc\"] += ops.mse_loss(pred, target_seq)\n", " metrics[\"mae_kc\"] += self.maeloss(pred, target_seq)\n", " self.main_module.test_aligned_score.update(pred, target_seq)\n", "\n", " for pred in unaligned_preds:\n", " metrics[\"mse\"] += ops.mse_loss(pred, target_seq)\n", " metrics[\"mae\"] += self.maeloss(pred, target_seq)\n", " self.main_module.test_score.update(pred, target_seq)\n", "\n", " pred_bchw = self._convert_to_bchw(pred)\n", " target_bchw = self._convert_to_bchw(target_seq)\n", " metrics[\"ssim\"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0]\n", "\n", " return metrics\n", "\n", " def _convert_to_bchw(self, tensor):\n", " \"\"\"Convert tensor to batch-channel-height-width format for metrics.\"\"\"\n", " return rearrange(tensor.asnumpy(), \"b t h w c -> (b t) c h w\")\n", "\n", " def _plt_pred(\n", " self, data_idx, context_seq, target_seq, 
aligned_preds, unaligned_preds, step\n", " ):\n", " \"\"\"Generate and save visualization of predictions.\"\"\"\n", " pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds]\n", " pred_labels = [\n", " f\"{self.logging_prefix}_aligned_pred_{i}\" for i in range(len(aligned_preds))\n", " ] + [f\"{self.logging_prefix}_pred_{i}\" for i in range(len(unaligned_preds))]\n", "\n", " self.save_vis_step_end(\n", " data_idx=data_idx,\n", " context_seq=context_seq[0].asnumpy(),\n", " target_seq=target_seq[0].asnumpy(),\n", " pred_seq=pred_sequences,\n", " pred_label=pred_labels,\n", " mode=\"test\",\n", " suffix=f\"_step_{step}\",\n", " )\n", "\n", " def _finalize_test(self, metrics):\n", " \"\"\"Complete test process and log final metrics.\"\"\"\n", " total_time = (time.time() - self.start_time) * 1000\n", " self.logger.info(f\"test cost: {total_time:.2f} ms\")\n", " self._compute_total_metrics(metrics)\n", " self.logger.info(\"============== Test Completed ==============\")\n", "\n", " def _compute_total_metrics(self, metrics):\n", " \"\"\"log_metrics\"\"\"\n", " step_count = max(metrics[\"step\"], 1)\n", " if self.eval_unaligned:\n", " self.logger.info(f\"MSE: {metrics['mse'] / step_count}\")\n", " self.logger.info(f\"MAE: {metrics['mae'] / step_count}\")\n", " self.logger.info(f\"SSIM: {metrics['ssim'] / step_count}\")\n", " test_score = self.main_module.test_score.eval()\n", " self.logger.info(\"SCORE:\\n%s\", json.dumps(test_score, indent=4))\n", " if self.use_alignment:\n", " self.logger.info(f\"KC_MSE: {metrics['mse_kc'] / step_count}\")\n", " self.logger.info(f\"KC_MAE: {metrics['mae_kc'] / step_count}\")\n", " aligned_score = self.main_module.test_aligned_score.eval()\n", " self.logger.info(\"KC_SCORE:\\n%s\", json.dumps(aligned_score, indent=4))\n", "\n", " def save_vis_step_end(\n", " self,\n", " data_idx: int,\n", " context_seq: np.ndarray,\n", " target_seq: np.ndarray,\n", " pred_seq: Union[np.ndarray, Sequence[np.ndarray]],\n", " pred_label: Union[str, Sequence[str]] = None,\n", " mode: str = \"train\",\n", " prefix: str = \"\",\n", " suffix: str = \"\",\n", " ):\n", " \"\"\"Save visualization of predictions with context and target.\"\"\"\n", " example_data_idx_list = self.test_example_data_idx_list\n", " if isinstance(pred_seq, Sequence):\n", " seq_list = [context_seq, target_seq] + list(pred_seq)\n", " label_list = [\"context\", \"target\"] + pred_label\n", " else:\n", " seq_list = [context_seq, target_seq, pred_seq]\n", " label_list = [\"context\", \"target\", pred_label]\n", " if data_idx in example_data_idx_list:\n", " png_save_name = f\"{prefix}{mode}_data_{data_idx}{suffix}.png\"\n", " vis_sevir_seq(\n", " save_path=os.path.join(self.example_save_dir, png_save_name),\n", " seq=seq_list,\n", " label=label_list,\n", " interval_real_time=10,\n", " plot_stride=1,\n", " fs=self.fs,\n", " label_offset=self.label_offset,\n", " label_avg_int=self.label_avg_int,\n", " )" ] }, { "cell_type": "code", "execution_count": 15, "id": "bf830d99-eec2-473e-8ffa-900fc2314b22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 14:04:16,558 - 2610859736.py[line:66] - INFO: ============== Start Test ==============\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[]\n", ".." 
] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 14:10:31,931 - 2610859736.py[line:201] - INFO: test cost: 375371.60 ms\n", "2025-04-07 14:10:31,937 - 2610859736.py[line:215] - INFO: KC_MSE: 0.0036273836\n", "2025-04-07 14:10:31,939 - 2610859736.py[line:216] - INFO: KC_MAE: 0.017427118\n", "2025-04-07 14:10:31,955 - 2610859736.py[line:218] - INFO: KC_SCORE:\n", "{\n", " \"16\": {\n", " \"csi\": 0.2715393900871277,\n", " \"pod\": 0.5063194632530212,\n", " \"sucr\": 0.369321346282959,\n", " \"bias\": 3.9119162559509277\n", " },\n", " \"74\": {\n", " \"csi\": 0.15696434676647186,\n", " \"pod\": 0.17386901378631592,\n", " \"sucr\": 0.6175059080123901,\n", " \"bias\": 0.16501028835773468\n", " }\n", "}\n", "2025-04-07 14:10:31,956 - 2610859736.py[line:203] - INFO: ============== Test Completed ==============\n" ] } ], "source": [ "main_module.main_model.set_train(False)\n", "params = ms.load_checkpoint(\"/PreDiff/summary/prediff/single_device0/ckpt/diffusion_epoch4.ckpt\")\n", "a, b = ms.load_param_into_net(main_module.main_model, params)\n", "print(b)\n", "tester = DiffusionInferrence(\n", " main_module=main_module, dm=dm, logger=logger, config=config\n", " )\n", "tester.test()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }