代码
开发者说 | 基于昇思MindSpore实现LDM扩散模型

开发者说 | 基于昇思MindSpore实现LDM扩散模型

开发者说 | 基于昇思MindSpore实现LDM扩散模型

作者:Adream

来源:昇思论坛

昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者Adrem输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。

# 01

概述

LDM(Latent Diffusion Model)是一种基于扩散过程的先进生成模型,其核心思想是通过逐步去除图像中的噪声来生成高质量的图像样本。与传统扩散模型不同,LDM 引入了潜在空间的概念,将图像表示为潜在空间中的向量。通过学习潜在空间中的映射关系,结合预训练好的变分自编码器(VAE),实现高效的图像生成。由于 LDM 在潜在空间进行操作,特征维度远小于图像空间,因此推理速度相较于基于 DDPM 的模型有显著提升。

LDM 的主要特点包括:

**1、VAE 潜在表示学习:**利用 VAE 学习图像的潜在表示,在压缩数据的同时保留关键特征,为后续的扩散过程提供低维高效的输入。

**2、UNet 映射学习:**采用 UNet 架构学习从潜在空间到图像空间的映射关系,充分利用 UNet 的多尺度特征提取和重建能力。

**3、条件扩散生成:**支持条件扩散过程,能够根据给定的条件(如文本、图像等)生成特定类型的图像,增强了模型的可控性。

**4、分层扩散机制:**运用分层扩散过程,在不同的分辨率层级上进行图像生成,有助于生成高分辨率、细节丰富的图像。

# 02

主要步骤

**1、**数据编码

使用预训练的 VAE 编码器将原始高维图像数据映射到低维潜在空间。这一过程不仅降低了计算复杂度,还提取了图像的关键特征,为后续的扩散过程提供了更紧凑的表示。

**2、**潜在空间扩散

在潜在空间中执行与 DDPM 类似的扩散过程。通过逐步向潜在表示添加高斯噪声,将其从原始数据分布转化为纯噪声分布。在反向过程中,通过 UNet 逐步去除噪声,生成新的潜在样本。

3、数据重建

使用预训练好的 VAE 解码器将去噪后的潜在表示映射回原始图像空间,得到最终的生成图像。

# 03

关键点

**1、**感知压缩与潜在空间优化

LDM 通过预训练的自动编码器将高维图像数据进行感知压缩,映射到低维潜在空间。例如,将 256×256 的图像压缩为 16×16 的潜在表示,大幅减少了计算量。为了控制潜在空间的分布,LDM 采用了以下两种正则化技术:

  • **KL 正则化:**约束潜在变量接近标准正态分布,类似于 VAE 的做法,增强了生成过程的稳定性。
  • **VQ 正则化:**通过向量量化层对潜在表示进行离散化,类似于 VQ - VAE,有助于提升模型对结构化特征的学习能力。

**2、**扩散过程与去噪机制

LDM 在潜在空间中进行扩散过程,分为前向扩散和反向去噪两个阶段:

  • **前向扩散:**逐步向潜在变量添加高斯噪声,噪声强度随时间步递增(如采用线性或余弦调度),最终使潜在表示接近纯噪声分布。
  • **反向去噪:**使用 UNet 预测当前潜在表示中的噪声,并逐步去除噪声。目标函数为预测噪声与实际噪声的均方误差,通过优化该目标函数提升模型的去噪能力。

**3、**多模态条件生成

LDM 引入了交叉注意力机制,支持文本、图像、语义地图等多模态条件输入。具体实现方式如下:

  • **条件编码器:**将不同模态的输入(如文本、图像)映射为中间表示,再与 UNet 的中间层进行交互。
  • **注意力融合:**通过查询(Q)、键(K)、值(V)机制,将条件信息融入生成过程,实现对图像生成的精准控制,例如根据文本描述生成相应的图像。

# 04

模型结构

**1、**像素空间与潜在空间的转换

  • **编码器(Encoder):**由卷积网络构成,通过下采样和特征提取操作,将输入图像 x从高维像素空间压缩到低维潜在空间,生成潜在表示 z。

# 05

MindSpo****re代码实现


import mindspore as ms
from mindspore import nn, ops, Tensor, Parameter
import mindspore.numpy as mnp
import numpy as np

# 超参数配置
config = {
    "latent_dim": 64,          # 潜在空间维度
    "image_size": 64,          # 输入图像尺寸
    "batch_size": 32,          # 批次大小
    "timesteps": 1000,         # 扩散时间步数
    "lr": 1e-4,                # 学习率
    "channels": [64, 128, 256],# UNet通道数
    "condition_dim": 128       # 条件嵌入维度
}

# 2. 编码器 - 解码器模块
class Encoder(nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='same', padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        # 对于 64x64 输入,特征图尺寸为 64x16x16
        self.fc = nn.Dense(64 * 16 * 16, config["latent_dim"])
   
    def construct(self, x):
        x = ops.relu(self.bn1(self.conv1(x)))
        x = self.pool(ops.relu(self.bn2(self.conv2(x))))
        x = x.view(x.shape[0], -1)
        return self.fc(x)

class Decoder(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(config["latent_dim"], 64 * 16 * 16)
        self.conv1 = nn.Conv2d(64, 32, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 3, 3, pad_mode='same', padding=1)
        self.upsample = nn.ResizeNearestNeighbor((32, 32))
    
    def construct(self, z):
        x = self.fc(z).view(-1, 64, 16, 16)
        x = self.upsample(ops.relu(self.bn1(self.conv1(x))))
        return self.conv2(x)

# 3. 时间编码模块
class TimeEmbedding(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.embed = nn.Embedding(config["timesteps"], dim)
        self.proj = nn.Dense(dim, dim)
    
    def construct(self, t):
        return self.proj(self.embed(t))

# 4. 交叉注意力模块
class CrossAttention(nn.Cell):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Dense(config["latent_dim"], config["condition_dim"])
        self.k_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
        self.v_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
        self.scale = mnp.sqrt(mnp.asarray(config["condition_dim"], dtype=ms.float32))
   
    def construct(self, z, condition):
        q = self.q_proj(z)
        k = self.k_proj(condition)
        v = self.v_proj(condition)
        attn = ops.matmul(q, k.transpose(0, 1)) / self.scale
        attn = ops.softmax(attn, axis=-1)
        return ops.matmul(attn, v)

# 5. UNet 去噪网络
class UNetBlock(nn.Cell):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, pad_mode='same', padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
    
    def construct(self, x):
        residual = x
        x = ops.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return x + residual  # 残差连接

class UNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.time_emb = TimeEmbedding(config["latent_dim"])
        self.down1 = UNetBlock(3, config["channels"][0])
        self.down2 = UNetBlock(config["channels"][0], config["channels"][1])
        self.down3 = UNetBlock(config["channels"][1], config["channels"][2])
        self.up2 = UNetBlock(config["channels"][1] + config["channels"][2], config["channels"][1])
        self.up1 = UNetBlock(config["channels"][0] + config["channels"][1], config["channels"][0])
        self.final = nn.Conv2d(config["channels"][0], 3, 3, pad_mode='same', padding=1)
   
    def construct(self, x, t, condition):
        # 时间条件注入
        t_emb = self.time_emb(t)
        t_emb = ops.broadcast_to(t_emb[:, None, None, :], (x.shape[0], x.shape[1], x.shape[2], t_emb.shape[1]))
        x = ops.concat((x, t_emb), axis=-1)
       
        # 下采样路径
        x1 = self.down1(x)
        x2 = self.down2(ops.max_pool2d(x1, 2))
        x3 = self.down3(ops.max_pool2d(x2, 2))
      
        # 上采样路径
        x = self.up2(ops.concat((ops.resize_nearest_neighbor(x3, scale=2), x2), axis=1))
        x = self.up1(ops.concat((ops.resize_nearest_neighbor(x, scale=2), x1), axis=1))
       
        # 交叉注意力融合条件
        attn = CrossAttention()(x.view(x.shape[0], -1, x.shape[-1]), condition)
        x = x + attn.view(x.shape)
       
        return self.final(x)

# 6. 完整 LDM 模型
class LDM(nn.Cell):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.unet = UNet()
        self.cross_attn = CrossAttention()
  
    def construct(self, x, t, condition):
        # 1. 编码到潜在空间
        z = self.encoder(x)
      
        # 2. 添加噪声(使用线性噪声调度)
        beta = 1e-4 + (0.02 - 1e-4) * (t / config["timesteps"])
        noise = ops.randn(z.shape) * mnp.sqrt(beta)
        z_noisy = z + noise
        
        # 3. UNet 去噪
        # 调整形状以适应 UNet 输入
        denoised = self.unet(z_noisy.view(-1, 64, 8, 8), t, condition)
       
        # 4. 解码回图像空间
        return self.decoder(denoised)

参考链接

[1] 论文地址:https://arxiv.org/pdf/2112.10752