代码
开发者说 | 基于昇思MindSpore实现DDIM扩散模型

开发者说 | 基于昇思MindSpore实现DDIM扩散模型

开发者说 | 基于昇思MindSpore实现DDIM扩散模型

作者:Adream

来源:昇思论坛

昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者Adrem输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。

# 01

概述

DDIM 是基于 DDPM 改进的迭代隐式概率扩撒模型,核心目标是在保持生成质量的同时加速采样过程。通过引入非马尔可夫扩散过程和确定性采样机制,DDIM 允许在去噪时跳过部分时间步,可以显著减少计算量。其核心创新在于:

**1、可调方差参数:**通过控制反向过程的随机性,实现从完全随机(DDPM)到完全确定(无噪声)的采样模式;

**2、跳跃式采样:**无需遍历所有时间步,可直接在预设的关键时间点之间跳转,大幅提升生成速度。

DDIM 的主要特点包括:

**1、****非马尔可夫过程:**打破 DDPM 的严格马尔可夫链限制,允许当前状态依赖任意历史状态;

**2、****确定性采样:**通过设置方差为 0,消除采样过程的随机性,提升生成稳定性;

**3、****采样效率:**支持“跳步”采样,在 10-50 步内即可生成高质量样本(DDPM 需 1000 步)。

# 02

主要步骤

**1、**正向扩散过程

DDIM 的正向扩散过程与 DDPM 一致,都是为了在每个时间步 t 中,逐渐增加噪声的比例,

# 04

模型结构

DDIM 沿用 DDPM 的 U-Net 架构作为主干网络,包含对称的编码器-解码器路径和跳跃连接,但针对采样效率进行了轻量化调整:

**1、**网络设计细节

  • **归一化与激活:**使用GroupNorm替代 BatchNorm 用以提升小批量训练稳定性,使用SiLU 激活函数替代 ReLU,增强非线性建模能力;
  • **时间嵌入:**将时间步 t编码为高维向量(如正弦编码或可学习嵌入),通过线性层与各层特征融合;
  • **跳跃连接:**保留原来的编码器-解码器的多尺度特征融合,确保细节恢复能力。

**2、**关键模块对比

  • **采样层:**DDIM 的p_sample方法通过判断σₜ是否为 0,决定是否添加随机噪声,默认σₜ=0 时为纯确定性计算;
  • **时间步处理:**支持任意时间步跳转,无需按顺序遍历,通过预设的时间步列表(如[T_s, T_{s-1}, ..., T_2, T_1])实现跳步采样。

# 05

MindSpo****re代码实现

# 核心采样逻辑
class DDIM(nn.Cell):
    """DDIM核心类,实现跳跃式确定性采样"""
    def __init__(self, model, betas, T=1000, sample_steps=50):
        super().__init__()
        self.model = model  # U-Net网络
        self.T = T          # 总时间步
        self.sample_steps = sample_steps  # 采样时使用的跳步步长
        self.betas = betas
        self.alphas = 1. - betas
        self.alpha_bars = np.cumprod(self.alphas)
        
        # 生成跳步时间序列(如从T到0,每隔T/sample_steps步取一个点)
        self.sampling_timesteps = np.linspace(0, T-1, sample_steps, dtype=np.int64)[::-1]
   
    def p_sample(self, x, t):
        """确定性去噪单步(σ=0)"""
        alpha = self.alphas[t]
        alpha_bar = self.alpha_bars[t]
        sqrt_alpha = ops.sqrt(alpha)
        sqrt_one_minus_alpha = ops.sqrt(1 - alpha)
       
        # 预测噪声并估计原始数据
        pred_noise = self.model(x, t)
        pred_x0 = (x - sqrt_one_minus_alpha * pred_noise) / sqrt_alpha
        
        # DDIM确定性采样公式
        alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else 1.0
        sqrt_alpha_bar_prev = ops.sqrt(alpha_bar_prev)
        sqrt_one_minus_alpha_bar_prev = ops.sqrt(1 - alpha_bar_prev)
        
        x_prev = sqrt_alpha_bar_prev * pred_x0 + sqrt_one_minus_alpha_bar_prev * pred_noise
        return x_prev
    
    def construct(self, x):
        """跳步采样过程(从x_T到x_0)"""
        for t in self.sampling_timesteps:
            x = self.p_sample(x, t)
        return x
# U-Net 改进
class UNet(nn.Cell):
    """带GroupNorm和SiLU的轻量化U-Net"""
    def __init__(self, in_channels=3, channel_dim=128):
        super().__init__()
        self.time_embed = nn.SequentialCell(
            nn.Embedding(1000, channel_dim),
            nn.SiLU(),
            nn.Dense(channel_dim, channel_dim * 4)
        )
        
        self.down = nn.SequentialCell(
            nn.Conv2d(in_channels, channel_dim, 3, padding=1),
            nn.GroupNorm(32, channel_dim),
            nn.SiLU(),
            nn.Conv2d(channel_dim, channel_dim * 2, 3, padding=1, stride=2),
            nn.GroupNorm(32, channel_dim * 2),
            nn.SiLU()
        )
        
        self.up = nn.SequentialCell(
            nn.Conv2dTranspose(channel_dim * 2, channel_dim, 3, stride=2, padding=1),
            nn.GroupNorm(32, channel_dim),
            nn.SiLU(),
            nn.Conv2d(channel_dim, in_channels, 3, padding=1),
            nn.Tanh()
        )
    
    def construct(self, x, t):
        t_emb = self.time_embed(t)
        h = self.down(x) + t_emb.view(-1, h.shape[1], 1, 1)
        return self.up(h)

参考链接

[1] 论文地址:https://arxiv.org/pdf/2010.02502