代码
扩散模型系列——DDPM

扩散模型系列——DDPM

扩散模型系列——DDPM

DDPM

论文地址:Denoising Diffusion Probabilistic Models

代码地址:https://github.com/hojonathanho/diffusion

一、概述

DDPM 是基于扩散过程的生成模型,通过逐步去除图像中的噪声来生成高质量的样本。其核心思想是将数据分布转换为更容易处理的简单高斯噪声分布,然后使用神经网络学习从该噪声分布中恢复原始数据的映射关系并通过优化变分下界简化训练,最终实现对复杂数据分布的建模。

DDPM 的主要特点包括:

  • 双向马尔可夫链:正向过程固定(可解析求解),反向过程参数化(神经网络学习);
  • 噪声预测机制:通过预测正向过程中添加的噪声间接恢复数据,避免显式估计复杂分布;
  • U-Net 架构:利用多尺度特征和跳跃连接捕捉图像细节,结合时间嵌入感知扩散阶段。

二、主要步骤

1. 正向扩散过程

正向扩散是为了在每个时间步 t中,逐渐增加噪声的比例,将原始数据 x_0转变为带有噪声的数据 x_t = \alpha_t x_0 + (1 - \alpha_t) \epsilon,其中 \alpha_t是扩散系数,\epsilon是标准正态分布的噪声。

  • 单步扩散:每一步添加微小高斯噪声,状态转移分布为:
    q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}\right)
    其中 \beta_t \in (0, 1)是预设的方差调度(通常随时间递增,如线性调度 \beta_t = \beta_{\text{start}} + t*(\beta_{\text{end}} - \beta_{\text{start}})/T)。\mathcal{N}(\cdot; \mu, \sigma^2 I)表示均值为 \mu,协方差矩阵为 \sigma^2 I的多元高斯分布。
  • 边际分布:通过递归推导单步扩散公式,任意时刻 t的数据分布可直接由 x_0表示为:
    q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t)I)
    \alpha_t = 1 - \beta_t
    \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s
    该公式表明, x_t可直接由 x_0和随机噪声 \epsilon \sim \mathcal{N}(0, \mathbf{I})生成:
    x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon

2. 反向去噪过程

反向过程是将噪声分布转化为原始数据分布的过程,其目标是逆转扩散过程,从纯噪声 x_t出发,逐步去噪,最终生成与原始数据分布 p_{data}(x_0)相近的样本 x_0。

  • 条件分布:反向过程的单步转移分布为参数化的高斯分布:
    p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\right)
    其中均值 \mu_\theta(x_t, t)由神经网络 \epsilon_\theta(x_t, t)预测噪声后推导得到:
    \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right)
    方差 \sigma_t^2可固定为预设值(如 \sigma_t^2 = \beta_t',论文中采用可学习或固定的方差调度)。 同时,p(x_T) = \mathcal{N}(x_T; 0, I)是预先定义的先验分布,通常为标准高斯分布。
  • 训练目标:通过最小化噪声预测的均方误差(MSE),使 \epsilon_\theta(x_t, t)逼近正向过程中添加的真实噪声 \epsilon。

3. 采样过程

从 x_T \sim \mathcal{N}(0, \mathbf{I})开始,迭代应用反向过程:

x_{t-1} = \mu_\theta(x_t, t) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, \mathbf{I})

直到生成 x_0,完成样本生成。

三、数学理论

1. 变分推断框架

DDPM 的目标是最大化数据对数似然 \log p_\theta(x_0),通过引入近似后验 q(x_{1:T} \mid x_0)(即正向扩散过程),构建变分下界: $ \log p_\theta(x_0) \geq \mathbb{E}{q(x{1:T}|x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} \right] =: L_{\text{VLB}} $ 通过展开和简化(推导可见原论文),最终下界可分解为易于计算的噪声预测损失:

L_{\text{simple}} = \mathbb{E}_{t \sim \mathcal{U}(1, T), x_0 \sim q(x_0), \epsilon \sim \mathcal{N}(0, \mathbf{I})} \left[ \|\epsilon - \epsilon_\theta\left( x_t, t \right)\|^2 \right]

其中 x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon是正向过程生成的含噪数据。

2. 关键公式推导

  • 正向过程的可加性:由于每一步添加的噪声独立,x_t可直接由 x_0和累积噪声权重生成,因此无需递归计算每一步 x_{t-1}。
  • 反向均值推导:利用贝叶斯公式和正向过程的高斯性质,将 \mu_\theta(x_t, t)表示为 x_t和预测噪声 \epsilon_\theta的线性组合,避免显式计算复杂的条件分布。
  • 变分下界分解:可分解为:
    L = \mathbb{E}_q \left[ D_{KL}(q(x_T | x_0) \parallel p(x_T)) + \sum_{t=2}^{T} D_{KL}(q(x_{t-1} | x_t, x_0) \parallel p_\theta(x_{t-1} | x_t)) - \log p_\theta(x_0 | x_1) \right]
    其中 D_{KL}为两个高斯分布的 KL 散度。

四、模型结构

DDPM 采用 U-Net 架构参数化反向过程的噪声预测函数 \epsilon_\theta(x_t, t),核心设计包括:

1. 网络架构

  • 编码器(下采样):通过卷积和池化层提取多尺度特征,捕捉图像的全局结构。
  • 解码器(上采样):通过转置卷积和跳跃连接恢复空间分辨率,融合编码器对应尺度的特征,保留细节信息。
  • 时间嵌入(Time Embedding): 将时间步 t编码为高维向量(如正弦位置编码或可学习嵌入),通过线性层映射后与各层特征相加或拼接,使模型感知当前扩散阶段。

2. 关键组件

  • 跳跃连接:直接传递编码器特征到解码器,避免下采样导致的信息丢失,增强细节恢复能力。
  • 自注意力机制:在中等分辨率层级(如 32×32 像素)引入自注意力层,建模长距离依赖关系,提升生成样本的全局一致性。

五、代码实现

import mindspore
from mindspore import nn, ops
import numpy as np

# 时间嵌入模块
class TimestepEmbedding(nn.Cell):
    """将时间步t编码为高维向量"""
    def __init__(self, embed_dim=128):
        super(TimestepEmbedding, self).__init__()
        self.embed_dim = embed_dim
        # 使用可学习的Embedding(替代正弦编码,简化实现)
        self.time_embed = nn.Embedding(1000, embed_dim)  # 假设最大时间步为1000

    def construct(self, t):
        # 输入t为形状(batch_size,)的整数张量
        return self.time_embed(t).view(-1, self.embed_dim, 1, 1)  # 扩展为(batch, dim, 1, 1)用于特征融合

# U-Net 模型(含时间嵌入)
class UNet(nn.Cell):
    """带跳跃连接和时间嵌入的U-Net,用于预测噪声ε_θ(x_t, t)"""
    def __init__(self, in_channels=1, out_channels=1, time_embed_dim=128):
        super(UNet, self).__init__()
        self.time_embedding = TimestepEmbedding(time_embed_dim)

        # 编码器(下采样路径)
        self.enc_conv1 = nn.Conv2d(in_channels, 64, 3, padding=1, pad_mode='pad')
        self.enc_conv2 = nn.Conv2d(64, 64, 3, padding=1, pad_mode='pad')
        self.enc_pool = nn.MaxPool2d(2)

        self.enc_conv3 = nn.Conv2d(64, 128, 3, padding=1, pad_mode='pad')
        self.enc_conv4 = nn.Conv2d(128, 128, 3, padding=1, pad_mode='pad')

        # 解码器(上采样路径)
        self.dec_transconv1 = nn.Conv2dTranspose(128, 64, 2, stride=2)  # 上采样到编码器第一层尺度
        self.dec_conv1 = nn.SequentialCell(
            nn.Conv2d(128, 64, 3, padding=1, pad_mode='pad'),  # 拼接后通道数64+64=128→64
            nn.ReLU()
        )

        self.dec_transconv2 = nn.Conv2dTranspose(64, 32, 2, stride=2)  # 上采样到输入尺度
        self.dec_conv2 = nn.SequentialCell(
            nn.Conv2d(33, 32, 3, padding=1, pad_mode='pad'),  # 输入通道数:32(上采样)+1(输入x)=33→32
            nn.ReLU()
        )

        self.final_conv = nn.Conv2d(32, out_channels, 3, padding=1, pad_mode='pad')  # 输出噪声预测
        self.relu = nn.ReLU()

    def construct(self, x, t):
        # 时间嵌入
        t_emb = self.time_embedding(t)

        # 编码器前向传播
        h1 = self.relu(self.enc_conv1(x))
        h1 = self.relu(self.enc_conv2(h1))  # (batch, 64, H, W)
        h1_pool = self.enc_pool(h1)  # (batch, 64, H/2, W/2)

        h2 = self.relu(self.enc_conv3(h1_pool))
        h2 = self.relu(self.enc_conv4(h2))  # (batch, 128, H/4, W/4)

        # 解码器反向传播
        h3 = self.dec_transconv1(h2)  # (batch, 64, H/2, W/2)
        h3 = ops.concat((h3, h1), axis=1)  # 跳跃连接:拼接编码器同尺度特征(64+64=128通道)
        h3 = self.dec_conv1(h3)  # (batch, 64, H/2, W/2)

        h4 = self.dec_transconv2(h3)  # (batch, 32, H, W)
        h4 = ops.concat((h4, x), axis=1)  # 拼接原始输入x(示例简化,实际应匹配尺度,此处假设输入为单通道)
        h4 = self.dec_conv2(h4)  # (batch, 32, H, W)

        out = self.final_conv(h4)  # 输出预测噪声,形状与输入x一致
        return out

# DDPM 模型主体
class DDPM(nn.Cell):
    """DDPM主类,管理扩散过程和损失计算"""
    def __init__(self, unet, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        super(DDPM, self).__init__()
        self.unet = unet
        self.num_timesteps = num_timesteps

        # 计算扩散参数(使用MindSpore张量,支持自动微分)
        self.betas = mindspore.Tensor(
            np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float32)
        )
        self.alphas = 1. - self.betas
        self.alphas_cumprod = ops.cumprod(self.alphas, 0)  # 累积乘积,形状(num_timesteps,)
        self.sqrt_alphas_cumprod = ops.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = ops.sqrt(1. - self.alphas_cumprod)

    def q_sample(self, x_start, t):
        """根据正向过程生成x_t = sqrt(α_t^bar)x0 + sqrt(1-α_t^bar)ε"""
        sqrt_alpha_prod = ops.gather(self.sqrt_alphas_cumprod, t, 0)  # 提取批次对应的α累积根
        sqrt_one_minus_alpha_prod = ops.gather(self.sqrt_one_minus_alphas_cumprod, t, 0)
        noise = ops.randn_like(x_start)  # 生成随机噪声
        return (sqrt_alpha_prod.view(-1, 1, 1, 1) * x_start +
                sqrt_one_minus_alpha_prod.view(-1, 1, 1, 1) * noise)

    def p_losses(self, x_start, t):
        """计算噪声预测损失:MSE(ε_θ(x_t, t), ε)"""
        noise = ops.randn_like(x_start)  # 真实噪声ε
        x_noisy = self.q_sample(x_start, t)  # 生成含噪数据x_t
        predicted_noise = self.unet(x_noisy, t)  # 模型预测噪声
        return nn.MSELoss()(predicted_noise, noise)

    def construct(self, x):
        """训练时的前向传播:随机采样时间步t,计算损失"""
        batch_size = x.shape[0]
        t = ops.randint(0, self.num_timesteps, (batch_size,), dtype=mindspore.int32)
        return self.p_losses(x, t)