{ "cells": [ { "cell_type": "markdown", "id": "bb7af2dd-61ff-4033-bc04-3a1ec598f5c4", "metadata": {}, "source": [ "# PreDiff: 基于潜在扩散模型的降水短时预报\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindearth/zh_cn/nowcasting/mindspore_prediffnet.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindearth/zh_cn/nowcasting/mindspore_prediffnet.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindearth/docs/source_zh_cn/nowcasting/prediffnet.ipynb)" ] }, { "cell_type": "markdown", "id": "fdd0a324", "metadata": {}, "source": [ "## 概述\n", "\n", "传统的天气预报技术依赖于复杂的物理模型,这些模型不仅计算成本高昂,还要求深厚的专业知识支撑。然而,近十年来,随着地球时空观测数据的爆炸式增长,深度学习技术为构建数据驱动的预测模型开辟了新的道路。虽然这些模型在多种地球系统预测任务中展现出巨大潜力,但它们在管理不确定性和整合特定领域先验知识方面仍有不足,时常导致预测结果模糊不清或在物理上不可信。\n", "\n", "为克服这些难题,来自香港科技大学的Gao Zhihan实现了**PreDiff**模型,专门用于实现概率性的时空预测。该流程融合了条件潜在扩散模型与显式的知识对齐机制,旨在生成既符合特定领域物理约束,又能精确捕捉时空变化的预测结果。通过这种方法,我们期望能够显著提升地球系统预测的准确性和可靠性。模型框架图如下图所示(图片来源于论文 [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS)):\n", "\n", "![prediff](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/docs/mindearth/docs/source_zh_cn/nowcasting/images/prediffnet.jpg)\n", "\n", "训练的过程中,数据通过变分自编码器提取关键信息到隐空间,之后随机选择时间步生成对应时间步噪声,对数据进行加噪处理。之后将数据输入Earthformer-UNet进行降噪处理,Earthformer-UNet采用了UNet构架和Cuboid Attention,去除了Earthformer中连接Encoder和Decoder的Cross Attention结构。最后,将结果通过变分自解码器还原得到去噪后的数据,扩散模型通过反转预先定义的破坏原始数据的加噪过程,来学习数据分布。" ] }, { "cell_type": "markdown", "id": "d3ca5c0b-4fe8-4764-b046-b154df49cc9b", "metadata": {}, "source": [ "## 技术路径\n", "\n", "MindSpore Earth求解该问题的具体流程如下:\n", "\n", "1. 创建数据集\n", "2. 模型构建\n", "3. 损失函数\n", "4. 模型训练\n", "5. 模型评估与可视化\n", "\n", "数据集可以在[PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)下载数据并保存。" ] }, { "cell_type": "code", "execution_count": 1, "id": "56631948-30aa-4360-84d1-e42bff3ab87c", "metadata": {}, "outputs": [], "source": [ "import time\n", "import os\n", "import random\n", "import json\n", "from typing import Sequence, Union\n", "import numpy as np\n", "from einops import rearrange\n", "\n", "import mindspore as ms\n", "from mindspore import set_seed, context, ops, nn, mint\n", "from mindspore.experimental import optim\n", "from mindspore.train.serialization import save_checkpoint" ] }, { "cell_type": "markdown", "id": "415eb386-b9ef-42af-9ab6-f98c9d8151da", "metadata": {}, "source": [ "下述src可以在[PreDiff/src](https://gitee.com/mindspore/mindscience/tree/r0.7/MindEarth/applications/nowcasting/PreDiff/src)下载。" ] }, { "cell_type": "code", "execution_count": 2, "id": "3ab31256-d802-4c68-9066-6c2cc9e73dcd", "metadata": {}, "outputs": [], "source": [ "from mindearth.utils import load_yaml_config\n", "\n", "from src import (\n", " prepare_output_directory,\n", " configure_logging_system,\n", " prepare_dataset,\n", " init_model,\n", " PreDiffModule,\n", " DiffusionTrainer,\n", " DiffusionInferrence\n", ")\n", "from src.sevir_dataset import SEVIRDataset\n", "from src.visual import vis_sevir_seq\n", "from src.utils import warmup_lambda" ] }, { "cell_type": "code", "execution_count": 3, "id": "e2576e2e-b98f-4791-8634-d68e55531ede", "metadata": {}, "outputs": [], "source": [ "set_seed(0)\n", "np.random.seed(0)\n", "random.seed(0)" ] }, { "cell_type": "markdown", "id": "7272ed00-61ef-439c-a420-b3bcefc13965", "metadata": {}, "source": [ "可以在[配置文件](https://gitee.com/mindspore/mindscience/tree/r0.7/MindEarth/applications/nowcasting/PreDiff/configs)中配置模型、数据和优化器等参数。" ] }, { "cell_type": "code", "execution_count": 4, "id": "6c2a2194-7a48-4412-b4d8-9bcae2d5c280", "metadata": {}, "outputs": [], "source": [ "config = load_yaml_config(\"./configs/diffusion.yaml\")\n", "context.set_context(mode=ms.PYNATIVE_MODE)\n", "ms.set_device(device_target=\"Ascend\", device_id=0)" ] }, { "cell_type": "markdown", "id": "9e4b94cd-f806-4e90-bf67-030e66253274", "metadata": {}, "source": [ "## 模型构建\n", "\n", "模型初始化主要包括变分自编码器和earthformer的初始化。" ] }, { "cell_type": "code", "execution_count": 5, "id": "36dabe74-de56-40c2-9db6-370ad2d7a0fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:32:11,466 - utils.py[line:820] - INFO: Process ID: 2231351\n", "2025-04-07 10:32:11,467 - utils.py[line:821] - INFO: {'summary_dir': './summary/prediff/single_device0', 'eval_interval': 10, 'save_ckpt_epochs': 1, 'keep_ckpt_max': 100, 'ckpt_path': '', 'load_ckpt': False}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "NoisyCuboidTransformerEncoder param_not_load: []\n", "Cleared previous output directory: ./summary/prediff/single_device0\n" ] } ], "source": [ "main_module = PreDiffModule(oc_file=\"./configs/diffusion.yaml\")\n", "main_module = init_model(module=main_module, config=config, mode=\"train\")\n", "output_dir = prepare_output_directory(config, \"0\")\n", "logger = configure_logging_system(output_dir, config)" ] }, { "cell_type": "markdown", "id": "ef546194-7eef-46a6-8ffa-504d4b58fa25", "metadata": {}, "source": [ "## 创建数据集\n", "\n", "下载[sevir-lr](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)数据集到./dataset目录。" ] }, { "cell_type": "code", "execution_count": 6, "id": "d4b45562-49b1-4ab3-ab52-bad420c30236", "metadata": {}, "outputs": [], "source": [ "dm, total_num_steps = prepare_dataset(config, PreDiffModule)" ] }, { "cell_type": "markdown", "id": "0fc4c0ce-9870-43f5-931b-30c8b5705122", "metadata": {}, "source": [ "## 损失函数\n", "\n", "PreDiff训练中使用mse作为loss计算,采用了梯度裁剪,并将过程封装在了DiffusionTrainer中。" ] }, { "cell_type": "code", "execution_count": 7, "id": "4e8ff7ff-3641-4a1a-b412-6587c5b09562", "metadata": {}, "outputs": [], "source": [ "class DiffusionTrainer(nn.Cell):\n", " \"\"\"\n", " Class managing the training pipeline for diffusion models. Handles dataset processing,\n", " optimizer configuration, gradient clipping, checkpoint saving, and logging.\n", " \"\"\"\n", " def __init__(self, main_module, dm, logger, config):\n", " \"\"\"\n", " Initialize trainer with model, data module, logger, and configuration.\n", " Args:\n", " main_module: Main diffusion model to be trained\n", " dm: Data module providing training dataset\n", " logger: Logging utility for training progress\n", " config: Configuration dictionary containing hyperparameters\n", " \"\"\"\n", " super().__init__()\n", " self.main_module = main_module\n", " self.traindataset = dm.sevir_train\n", " self.logger = logger\n", " self.datasetprocessing = SEVIRDataset(\n", " data_types=[\"vil\"],\n", " layout=\"NHWT\",\n", " rescale_method=config.get(\"rescale_method\", \"01\"),\n", " )\n", " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", " self.fs = config[\"eval\"].get(\"fs\", 20)\n", " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", " self.current_epoch = 0\n", " self.learn_logvar = (\n", " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", False)\n", " )\n", " self.logvar = main_module.logvar\n", " self.maeloss = nn.MAELoss()\n", " self.optim_config = config[\"optim\"]\n", " self.clip_norm = config.get(\"clip_norm\", 2.0)\n", " self.ckpt_dir = os.path.join(self.example_save_dir, \"ckpt\")\n", " self.keep_ckpt_max = config[\"summary\"].get(\"keep_ckpt_max\", 100)\n", " self.ckpt_history = []\n", " self.grad_clip_fn = ops.clip_by_global_norm\n", " self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), learning_rate=0.00001)\n", " os.makedirs(self.ckpt_dir, exist_ok=True)\n", "\n", " def train(self, total_steps: int):\n", " \"\"\"Execute complete training pipeline.\"\"\"\n", " self.main_module.main_model.set_train(True)\n", " self.logger.info(\"Initializing training process...\")\n", " loss_processor = Trainonestepforward(self.main_module)\n", " grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params())\n", " for epoch in range(self.optim_config[\"max_epochs\"]):\n", " epoch_loss = 0.0\n", " epoch_start = time.time()\n", "\n", " iterator = self.traindataset.create_dict_iterator()\n", " assert iterator, \"dataset is empty\"\n", " batch_idx = 0\n", " for batch_idx, batch in enumerate(iterator):\n", " processed_data = self.datasetprocessing.process_data(batch[\"vil\"])\n", " loss_value, gradients = grad_func(processed_data)\n", " clipped_grads = self.grad_clip_fn(gradients, self.clip_norm)\n", " self.optimizer(clipped_grads)\n", " epoch_loss += loss_value.asnumpy()\n", " self.logger.info(\n", " f\"epoch: {epoch} step: {batch_idx}, loss: {loss_value}\"\n", " )\n", " self._save_ckpt(epoch)\n", " epoch_time = time.time() - epoch_start\n", " self.logger.info(\n", " f\"Epoch {epoch} completed in {epoch_time:.2f}s | \"\n", " f\"Avg Loss: {epoch_loss/(batch_idx+1):.4f}\"\n", " )\n", "\n", " def _get_optimizer(self, total_steps: int):\n", " \"\"\"Configure optimization components\"\"\"\n", " trainable_params = list(self.main_module.main_model.trainable_params())\n", " if self.learn_logvar:\n", " self.logger.info(\"Including log variance parameters\")\n", " trainable_params.append(self.logvar)\n", " optimizer = optim.AdamW(\n", " trainable_params,\n", " lr=self.optim_config[\"lr\"],\n", " betas=tuple(self.optim_config[\"betas\"]),\n", " )\n", " warmup_steps = int(self.optim_config[\"warmup_percentage\"] * total_steps)\n", " scheduler = self._create_lr_scheduler(optimizer, total_steps, warmup_steps)\n", "\n", " return optimizer, scheduler\n", "\n", " def _create_lr_scheduler(self, optimizer, total_steps: int, warmup_steps: int):\n", " \"\"\"Build learning rate scheduler\"\"\"\n", " warmup_scheduler = optim.lr_scheduler.LambdaLR(\n", " optimizer,\n", " lr_lambda=warmup_lambda(\n", " warmup_steps=warmup_steps,\n", " min_lr_ratio=self.optim_config[\"warmup_min_lr_ratio\"],\n", " ),\n", " )\n", "\n", " cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(\n", " optimizer,\n", " T_max=total_steps - warmup_steps,\n", " eta_min=self.optim_config[\"min_lr_ratio\"] * self.optim_config[\"lr\"],\n", " )\n", "\n", " return optim.lr_scheduler.SequentialLR(\n", " optimizer,\n", " schedulers=[warmup_scheduler, cosine_scheduler],\n", " milestones=[warmup_steps],\n", " )\n", "\n", " def _save_ckpt(self, epoch: int):\n", " \"\"\"Save model ckpt with rotation policy\"\"\"\n", " ckpt_file = f\"diffusion_epoch{epoch}.ckpt\"\n", " ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)\n", "\n", " save_checkpoint(self.main_module.main_model, ckpt_path)\n", " self.ckpt_history.append(ckpt_path)\n", "\n", " if len(self.ckpt_history) > self.keep_ckpt_max:\n", " removed_ckpt = self.ckpt_history.pop(0)\n", " os.remove(removed_ckpt)\n", " self.logger.info(f\"Removed outdated ckpt: {removed_ckpt}\")\n", "\n", "\n", "class Trainonestepforward(nn.Cell):\n", " \"\"\"A neural network cell that performs one training step forward pass for a diffusion model.\n", " This class encapsulates the forward pass computation for training a diffusion model,\n", " handling the input processing, latent space encoding, conditioning, and loss calculation.\n", " Args:\n", " model (nn.Cell): The main diffusion model containing the necessary submodules\n", " for encoding, conditioning, and loss computation.\n", " \"\"\"\n", "\n", " def __init__(self, model):\n", " super().__init__()\n", " self.main_module = model\n", "\n", " def construct(self, inputs):\n", " \"\"\"Perform one forward training step and compute the loss.\"\"\"\n", " x, condition = self.main_module.get_input(inputs)\n", " x = x.transpose(0, 1, 4, 2, 3)\n", " n, t_, c_, h, w = x.shape\n", " x = x.reshape(n * t_, c_, h, w)\n", " z = self.main_module.encode_first_stage(x)\n", " _, c_z, h_z, w_z = z.shape\n", " z = z.reshape(n, -1, c_z, h_z, w_z)\n", " z = z.transpose(0, 1, 3, 4, 2)\n", " t = ops.randint(0, self.main_module.num_timesteps, (n,)).long()\n", " zc = self.main_module.cond_stage_forward(condition)\n", " loss = self.main_module.p_losses(z, zc, t, noise=None)\n", " return loss" ] }, { "cell_type": "markdown", "id": "da1c23d7-fbff-4492-af19-79f30bcc0185", "metadata": {}, "source": [ "## 模型训练\n", "\n", "在本教程中,我们使用DiffusionTrainer对模型进行训练。" ] }, { "cell_type": "code", "execution_count": 8, "id": "024a0222-c7c7-4cd0-9b6f-7350b92619af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:32:36,351 - 4106154625.py[line:46] - INFO: Initializing training process...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "........." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:09,378 - 4106154625.py[line:64] - INFO: epoch: 0 step: 0, loss: 1.0008465\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:16,871 - 4106154625.py[line:64] - INFO: epoch: 0 step: 1, loss: 1.0023363\n", "2025-04-07 10:34:18,724 - 4106154625.py[line:64] - INFO: epoch: 0 step: 2, loss: 1.0009086\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:20,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 3, loss: 0.99787366\n", "2025-04-07 10:34:22,280 - 4106154625.py[line:64] - INFO: epoch: 0 step: 4, loss: 0.9979043\n", "2025-04-07 10:34:24,072 - 4106154625.py[line:64] - INFO: epoch: 0 step: 5, loss: 0.99897844\n", "2025-04-07 10:34:25,864 - 4106154625.py[line:64] - INFO: epoch: 0 step: 6, loss: 1.0021904\n", "2025-04-07 10:34:27,709 - 4106154625.py[line:64] - INFO: epoch: 0 step: 7, loss: 0.9984627\n", "2025-04-07 10:34:29,578 - 4106154625.py[line:64] - INFO: epoch: 0 step: 8, loss: 0.9952746\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 10:34:31,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 9, loss: 1.0003254\n", "2025-04-07 10:34:33,402 - 4106154625.py[line:64] - INFO: epoch: 0 step: 10, loss: 1.0020428\n", "2025-04-07 10:34:35,218 - 4106154625.py[line:64] - INFO: epoch: 0 step: 11, loss: 0.99563503\n", "2025-04-07 10:34:37,149 - 4106154625.py[line:64] - INFO: epoch: 0 step: 12, loss: 0.99336195\n", "2025-04-07 10:34:38,949 - 4106154625.py[line:64] - INFO: epoch: 0 step: 13, loss: 1.0023757\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "......" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 13:39:55,859 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1247, loss: 0.021378823\n", "2025-04-07 13:39:57,754 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1248, loss: 0.01565772\n", "2025-04-07 13:39:59,606 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1249, loss: 0.012067624\n", "2025-04-07 13:40:01,396 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1250, loss: 0.017700804\n", "2025-04-07 13:40:03,181 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1251, loss: 0.06254268\n", "2025-04-07 13:40:04,945 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1252, loss: 0.013293369\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 13:40:06,770 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1253, loss: 0.026906993\n", "2025-04-07 13:40:08,644 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1254, loss: 0.18210539\n", "2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894\n", "2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517\n" ] } ], "source": [ "trainer = DiffusionTrainer(\n", " main_module=main_module, dm=dm, logger=logger, config=config\n", ")\n", "trainer.train(total_steps=total_num_steps)" ] }, { "cell_type": "markdown", "id": "5be19a94-d51c-487f-bfde-8a9d9dc57d79", "metadata": {}, "source": [ "## 模型评估与可视化\n", "\n", "完成训练后,我们使用第5个ckpt进行推理。下述展示了预测值与实际值之间的误差和各项指标。" ] }, { "cell_type": "code", "execution_count": 14, "id": "e0654f48-d02d-4258-b86c-61da318cf10e", "metadata": {}, "outputs": [], "source": [ "def get_alignment_kwargs_avg_x(target_seq):\n", " \"\"\"Generate alignment parameters for guided sampling\"\"\"\n", " batch_size = target_seq.shape[0]\n", " avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True)\n", " return {\"avg_x_gt\": avg_intensity * 2.0}\n", "\n", "\n", "class DiffusionInferrence(nn.Cell):\n", " \"\"\"\n", " Class managing model inference and evaluation processes. Handles loading checkpoints,\n", " generating predictions, calculating evaluation metrics, and saving visualization results.\n", " \"\"\"\n", " def __init__(self, main_module, dm, logger, config):\n", " \"\"\"\n", " Initialize inference manager with model, data module, logger, and configuration.\n", " Args:\n", " main_module: Main diffusion model for inference\n", " dm: Data module providing test dataset\n", " logger: Logging utility for evaluation progress\n", " config: Configuration dictionary containing evaluation parameters\n", " \"\"\"\n", " super().__init__()\n", " self.num_samples = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", " self.eval_example_only = config[\"eval\"].get(\"eval_example_only\", True)\n", " self.alignment_type = (\n", " config.get(\"model\", {}).get(\"align\", {}).get(\"alignment_type\", \"avg_x\")\n", " )\n", " self.use_alignment = self.alignment_type is not None\n", " self.eval_aligned = config[\"eval\"].get(\"eval_aligned\", True)\n", " self.eval_unaligned = config[\"eval\"].get(\"eval_unaligned\", True)\n", " self.num_samples_per_context = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", " self.logging_prefix = config[\"logging\"].get(\"logging_prefix\", \"PreDiff\")\n", " self.test_example_data_idx_list = [48]\n", " self.main_module = main_module\n", " self.testdataset = dm.sevir_test\n", " self.logger = logger\n", " self.datasetprocessing = SEVIRDataset(\n", " data_types=[\"vil\"],\n", " layout=\"NHWT\",\n", " rescale_method=config.get(\"rescale_method\", \"01\"),\n", " )\n", " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", "\n", " self.fs = config[\"eval\"].get(\"fs\", 20)\n", " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", "\n", " self.current_epoch = 0\n", "\n", " self.learn_logvar = (\n", " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", False)\n", " )\n", " self.logvar = main_module.logvar\n", " self.maeloss = nn.MAELoss()\n", " self.test_metrics = {\n", " \"step\": 0,\n", " \"mse\": 0.0,\n", " \"mae\": 0.0,\n", " \"ssim\": 0.0,\n", " \"mse_kc\": 0.0,\n", " \"mae_kc\": 0.0,\n", " }\n", "\n", " def test(self):\n", " \"\"\"Execute complete evaluation pipeline.\"\"\"\n", " self.logger.info(\"============== Start Test ==============\")\n", " self.start_time = time.time()\n", " for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()):\n", " self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics)\n", "\n", " self._finalize_test(self.test_metrics)\n", "\n", " def _test_onestep(self, item, batch_idx, metrics):\n", " \"\"\"Process one test batch and update evaluation metrics.\"\"\"\n", " data_idx = int(batch_idx * 2)\n", " if not self._should_test_onestep(data_idx):\n", " return metrics\n", " data = item.get(\"vil\")\n", " data = self.datasetprocessing.process_data(data)\n", " target_seq, cond, context_seq = self._get_model_inputs(data)\n", " aligned_preds, unaligned_preds = self._generate_predictions(\n", " cond, target_seq\n", " )\n", " metrics = self._update_metrics(\n", " aligned_preds, unaligned_preds, target_seq, metrics\n", " )\n", " self._plt_pred(\n", " data_idx,\n", " context_seq,\n", " target_seq,\n", " aligned_preds,\n", " unaligned_preds,\n", " metrics[\"step\"],\n", " )\n", "\n", " metrics[\"step\"] += 1\n", " return metrics\n", "\n", " def _should_test_onestep(self, data_idx):\n", " \"\"\"Determine if evaluation should be performed on current data index.\"\"\"\n", " return (not self.eval_example_only) or (\n", " data_idx in self.test_example_data_idx_list\n", " )\n", "\n", " def _get_model_inputs(self, data):\n", " \"\"\"Extract and prepare model inputs from raw data.\"\"\"\n", " target_seq, cond, context_seq = self.main_module.get_input(\n", " data, return_verbose=True\n", " )\n", " return target_seq, cond, context_seq\n", "\n", " def _generate_predictions(self, cond, target_seq):\n", " \"\"\"Generate both aligned and unaligned predictions from the model.\"\"\"\n", " aligned_preds = []\n", " unaligned_preds = []\n", "\n", " for _ in range(self.num_samples_per_context):\n", " if self.use_alignment and self.eval_aligned:\n", " aligned_pred = self._sample_with_alignment(\n", " cond, target_seq\n", " )\n", " aligned_preds.append(aligned_pred)\n", "\n", " if self.eval_unaligned:\n", " unaligned_pred = self._sample_without_alignment(cond)\n", " unaligned_preds.append(unaligned_pred)\n", "\n", " return aligned_preds, unaligned_preds\n", "\n", " def _sample_with_alignment(self, cond, target_seq):\n", " \"\"\"Generate predictions using alignment mechanism.\"\"\"\n", " alignment_kwargs = get_alignment_kwargs_avg_x(target_seq)\n", " pred_seq = self.main_module.sample(\n", " cond=cond,\n", " batch_size=cond[\"y\"].shape[0],\n", " return_intermediates=False,\n", " use_alignment=True,\n", " alignment_kwargs=alignment_kwargs,\n", " verbose=False,\n", " )\n", " if pred_seq.dtype != ms.float32:\n", " pred_seq = pred_seq.float()\n", " return pred_seq\n", "\n", " def _sample_without_alignment(self, cond):\n", " \"\"\"Generate predictions without alignment.\"\"\"\n", " pred_seq = self.main_module.sample(\n", " cond=cond,\n", " batch_size=cond[\"y\"].shape[0],\n", " return_intermediates=False,\n", " verbose=False,\n", " )\n", " if pred_seq.dtype != ms.float32:\n", " pred_seq = pred_seq.float()\n", " return pred_seq\n", "\n", " def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics):\n", " \"\"\"Update evaluation metrics with new predictions.\"\"\"\n", " for pred in aligned_preds:\n", " metrics[\"mse_kc\"] += ops.mse_loss(pred, target_seq)\n", " metrics[\"mae_kc\"] += self.maeloss(pred, target_seq)\n", " self.main_module.test_aligned_score.update(pred, target_seq)\n", "\n", " for pred in unaligned_preds:\n", " metrics[\"mse\"] += ops.mse_loss(pred, target_seq)\n", " metrics[\"mae\"] += self.maeloss(pred, target_seq)\n", " self.main_module.test_score.update(pred, target_seq)\n", "\n", " pred_bchw = self._convert_to_bchw(pred)\n", " target_bchw = self._convert_to_bchw(target_seq)\n", " metrics[\"ssim\"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0]\n", "\n", " return metrics\n", "\n", " def _convert_to_bchw(self, tensor):\n", " \"\"\"Convert tensor to batch-channel-height-width format for metrics.\"\"\"\n", " return rearrange(tensor.asnumpy(), \"b t h w c -> (b t) c h w\")\n", "\n", " def _plt_pred(\n", " self, data_idx, context_seq, target_seq, aligned_preds, unaligned_preds, step\n", " ):\n", " \"\"\"Generate and save visualization of predictions.\"\"\"\n", " pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds]\n", " pred_labels = [\n", " f\"{self.logging_prefix}_aligned_pred_{i}\" for i in range(len(aligned_preds))\n", " ] + [f\"{self.logging_prefix}_pred_{i}\" for i in range(len(unaligned_preds))]\n", "\n", " self.save_vis_step_end(\n", " data_idx=data_idx,\n", " context_seq=context_seq[0].asnumpy(),\n", " target_seq=target_seq[0].asnumpy(),\n", " pred_seq=pred_sequences,\n", " pred_label=pred_labels,\n", " mode=\"test\",\n", " suffix=f\"_step_{step}\",\n", " )\n", "\n", " def _finalize_test(self, metrics):\n", " \"\"\"Complete test process and log final metrics.\"\"\"\n", " total_time = (time.time() - self.start_time) * 1000\n", " self.logger.info(f\"test cost: {total_time:.2f} ms\")\n", " self._compute_total_metrics(metrics)\n", " self.logger.info(\"============== Test Completed ==============\")\n", "\n", " def _compute_total_metrics(self, metrics):\n", " \"\"\"log_metrics\"\"\"\n", " step_count = max(metrics[\"step\"], 1)\n", " if self.eval_unaligned:\n", " self.logger.info(f\"MSE: {metrics['mse'] / step_count}\")\n", " self.logger.info(f\"MAE: {metrics['mae'] / step_count}\")\n", " self.logger.info(f\"SSIM: {metrics['ssim'] / step_count}\")\n", " test_score = self.main_module.test_score.eval()\n", " self.logger.info(\"SCORE:\\n%s\", json.dumps(test_score, indent=4))\n", " if self.use_alignment:\n", " self.logger.info(f\"KC_MSE: {metrics['mse_kc'] / step_count}\")\n", " self.logger.info(f\"KC_MAE: {metrics['mae_kc'] / step_count}\")\n", " aligned_score = self.main_module.test_aligned_score.eval()\n", " self.logger.info(\"KC_SCORE:\\n%s\", json.dumps(aligned_score, indent=4))\n", "\n", " def save_vis_step_end(\n", " self,\n", " data_idx: int,\n", " context_seq: np.ndarray,\n", " target_seq: np.ndarray,\n", " pred_seq: Union[np.ndarray, Sequence[np.ndarray]],\n", " pred_label: Union[str, Sequence[str]] = None,\n", " mode: str = \"train\",\n", " prefix: str = \"\",\n", " suffix: str = \"\",\n", " ):\n", " \"\"\"Save visualization of predictions with context and target.\"\"\"\n", " example_data_idx_list = self.test_example_data_idx_list\n", " if isinstance(pred_seq, Sequence):\n", " seq_list = [context_seq, target_seq] + list(pred_seq)\n", " label_list = [\"context\", \"target\"] + pred_label\n", " else:\n", " seq_list = [context_seq, target_seq, pred_seq]\n", " label_list = [\"context\", \"target\", pred_label]\n", " if data_idx in example_data_idx_list:\n", " png_save_name = f\"{prefix}{mode}_data_{data_idx}{suffix}.png\"\n", " vis_sevir_seq(\n", " save_path=os.path.join(self.example_save_dir, png_save_name),\n", " seq=seq_list,\n", " label=label_list,\n", " interval_real_time=10,\n", " plot_stride=1,\n", " fs=self.fs,\n", " label_offset=self.label_offset,\n", " label_avg_int=self.label_avg_int,\n", " )" ] }, { "cell_type": "code", "execution_count": 15, "id": "bf830d99-eec2-473e-8ffa-900fc2314b22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 14:04:16,558 - 2610859736.py[line:66] - INFO: ============== Start Test ==============\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[]\n", ".." ] }, { "name": "stdout", "output_type": "stream", "text": [ "2025-04-07 14:10:31,931 - 2610859736.py[line:201] - INFO: test cost: 375371.60 ms\n", "2025-04-07 14:10:31,937 - 2610859736.py[line:215] - INFO: KC_MSE: 0.0036273836\n", "2025-04-07 14:10:31,939 - 2610859736.py[line:216] - INFO: KC_MAE: 0.017427118\n", "2025-04-07 14:10:31,955 - 2610859736.py[line:218] - INFO: KC_SCORE:\n", "{\n", " \"16\": {\n", " \"csi\": 0.2715393900871277,\n", " \"pod\": 0.5063194632530212,\n", " \"sucr\": 0.369321346282959,\n", " \"bias\": 3.9119162559509277\n", " },\n", " \"74\": {\n", " \"csi\": 0.15696434676647186,\n", " \"pod\": 0.17386901378631592,\n", " \"sucr\": 0.6175059080123901,\n", " \"bias\": 0.16501028835773468\n", " }\n", "}\n", "2025-04-07 14:10:31,956 - 2610859736.py[line:203] - INFO: ============== Test Completed ==============\n" ] } ], "source": [ "main_module.main_model.set_train(False)\n", "params = ms.load_checkpoint(\"/PreDiff/summary/prediff/single_device0/ckpt/diffusion_epoch4.ckpt\")\n", "a, b = ms.load_param_into_net(main_module.main_model, params)\n", "print(b)\n", "tester = DiffusionInferrence(\n", " main_module=main_module, dm=dm, logger=logger, config=config\n", " )\n", "tester.test()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }