代码
扩散模型系列——LDM

扩散模型系列——LDM

扩散模型系列——LDM

LDM

论文地址:High-Resolution Image Synthesis with Latent Diffusion Models

代码地址:https://github.com/CompVis/latent-diffusion

一、概述

LDM(Latent Diffusion Model)是一种基于扩散过程的先进生成模型,其核心思想是通过逐步去除图像中的噪声来生成高质量的图像样本。与传统扩散模型不同,LDM 引入了潜在空间的概念,将图像表示为潜在空间中的向量。通过学习潜在空间中的映射关系,结合预训练好的变分自编码器(VAE),实现高效的图像生成。由于 LDM 在潜在空间进行操作,特征维度远小于图像空间,因此推理速度相较于基于 DDPM 的模型有显著提升。

LDM 的主要特点包括:

  1. VAE 潜在表示学习:利用 VAE 学习图像的潜在表示,在压缩数据的同时保留关键特征,为后续的扩散过程提供低维高效的输入。
  2. UNet 映射学习:采用 UNet 架构学习从潜在空间到图像空间的映射关系,充分利用 UNet 的多尺度特征提取和重建能力。
  3. 条件扩散生成:支持条件扩散过程,能够根据给定的条件(如文本、图像等)生成特定类型的图像,增强了模型的可控性。
  4. 分层扩散机制:运用分层扩散过程,在不同的分辨率层级上进行图像生成,有助于生成高分辨率、细节丰富的图像。

二、主要步骤

1. 数据编码

使用预训练的 VAE 编码器将原始高维图像数据映射到低维潜在空间。这一过程不仅降低了计算复杂度,还提取了图像的关键特征,为后续的扩散过程提供了更紧凑的表示。

2. 潜在空间扩散

在潜在空间中执行与 DDPM 类似的扩散过程。通过逐步向潜在表示添加高斯噪声,将其从原始数据分布转化为纯噪声分布。在反向过程中,通过 UNet 逐步去除噪声,生成新的潜在样本。

3. 数据重建

使用预训练好的 VAE 解码器将去噪后的潜在表示映射回原始图像空间,得到最终的生成图像。

三、关键点

1. 感知压缩与潜在空间优化

LDM 通过预训练的自动编码器将高维图像数据进行感知压缩,映射到低维潜在空间。例如,将 256×256 的图像压缩为 16×16 的潜在表示,大幅减少了计算量。为了控制潜在空间的分布,LDM 采用了以下两种正则化技术:

  • KL 正则化:约束潜在变量接近标准正态分布,类似于 VAE 的做法,增强了生成过程的稳定性。
  • VQ 正则化:通过向量量化层对潜在表示进行离散化,类似于 VQ - VAE,有助于提升模型对结构化特征的学习能力。

2. 扩散过程与去噪机制

LDM 在潜在空间中进行扩散过程,分为前向扩散和反向去噪两个阶段:

  • 前向扩散:逐步向潜在变量添加高斯噪声,噪声强度随时间步递增(如采用线性或余弦调度),最终使潜在表示接近纯噪声分布。
  • 反向去噪:使用 UNet 预测当前潜在表示中的噪声,并逐步去除噪声。目标函数为预测噪声与实际噪声的均方误差,通过优化该目标函数提升模型的去噪能力。

3. 多模态条件生成

LDM 引入了交叉注意力机制,支持文本、图像、语义地图等多模态条件输入。具体实现方式如下:

  • 条件编码器:将不同模态的输入(如文本、图像)映射为中间表示,再与 UNet 的中间层进行交互。
  • 注意力融合:通过查询(Q)、键(K)、值(V)机制,将条件信息融入生成过程,实现对图像生成的精准控制,例如根据文本描述生成相应的图像。

四、模型结构

1. 像素空间与潜在空间的转换

  • 编码器(Encoder):由卷积网络构成,通过下采样和特征提取操作,将输入图像 x从高维像素空间压缩到低维潜在空间,生成潜在表示 z。
  • 解码器(Decoder):包含上采样和特征重建层,将去噪后的潜在表示 \tilde{z}映射回像素空间,生成重建图像 \tilde{x}。

2. 潜在空间中的扩散过程

  • 前向扩散:在潜在空间中,逐步向潜在表示 z添加高斯噪声,生成一系列噪声版本 z_t(t表示时间步)。
  • 噪声调度:噪声强度随时间步递增,常见的调度方式有线性调度和余弦调度,最终使 z_T接近纯噪声。

3. 去噪 U-Net 核心模块

  • 网络结构:
    • 下采样路径:通过卷积层、批归一化(BN)和激活函数(如 ReLU)逐步提取多尺度特征。
    • 上采样路径:使用反卷积或上采样操作恢复空间分辨率。
    • 跳跃连接:将下采样层的特征直接传递到对应的上采样层,帮助保留图像的细节信息。
  • 时间条件注入:
    • 通过可学习的时间嵌入(如正弦位置编码)将时间步 t编码为向量。
    • 将时间嵌入向量与 UNet 中间特征进行拼接,指导模型根据扩散进度调整去噪策略。
  • 跨注意力机制:
    • 条件融合:将条件信息(如文本、语义图)通过条件模块(如 Transformer)编码为查询(Q)。
    • 潜在表示交互:将潜在表示 z_t作为键(K)和值(V),通过注意力机制与条件查询交互。公式:
      \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) \cdot V

4. 条件输入处理

  • 条件编码器:针对不同模态(如文本、图像)设计专用编码器,将输入映射为固定维度的向量。
  • 与 UNet 交互:条件编码器的输出作为跨注意力层的查询,指导 UNet 的去噪过程,从而实现根据条件生成图像。

5. 去噪迭代过程

从纯噪声 z_T开始,通过 UNet 逐步去噪,每次迭代中 UNet 预测当前噪声并去除,最终生成干净的潜在表示 z_0,逐步逼近真实数据分布。

五、代码实现

import mindspore as ms
from mindspore import nn, ops, Tensor, Parameter
import mindspore.numpy as mnp
import numpy as np

# 超参数配置
config = {
    "latent_dim": 64,          # 潜在空间维度
    "image_size": 64,          # 输入图像尺寸
    "batch_size": 32,          # 批次大小
    "timesteps": 1000,         # 扩散时间步数
    "lr": 1e-4,                # 学习率
    "channels": [64, 128, 256],# UNet通道数
    "condition_dim": 128       # 条件嵌入维度
}

# 2. 编码器 - 解码器模块
class Encoder(nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='same', padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        # 对于 64x64 输入,特征图尺寸为 64x16x16
        self.fc = nn.Dense(64 * 16 * 16, config["latent_dim"])

    def construct(self, x):
        x = ops.relu(self.bn1(self.conv1(x)))
        x = self.pool(ops.relu(self.bn2(self.conv2(x))))
        x = x.view(x.shape[0], -1)
        return self.fc(x)

class Decoder(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(config["latent_dim"], 64 * 16 * 16)
        self.conv1 = nn.Conv2d(64, 32, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 3, 3, pad_mode='same', padding=1)
        self.upsample = nn.ResizeNearestNeighbor((32, 32))

    def construct(self, z):
        x = self.fc(z).view(-1, 64, 16, 16)
        x = self.upsample(ops.relu(self.bn1(self.conv1(x))))
        return self.conv2(x)

# 3. 时间编码模块
class TimeEmbedding(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.embed = nn.Embedding(config["timesteps"], dim)
        self.proj = nn.Dense(dim, dim)

    def construct(self, t):
        return self.proj(self.embed(t))

# 4. 交叉注意力模块
class CrossAttention(nn.Cell):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Dense(config["latent_dim"], config["condition_dim"])
        self.k_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
        self.v_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
        self.scale = mnp.sqrt(mnp.asarray(config["condition_dim"], dtype=ms.float32))

    def construct(self, z, condition):
        q = self.q_proj(z)
        k = self.k_proj(condition)
        v = self.v_proj(condition)

        attn = ops.matmul(q, k.transpose(0, 1)) / self.scale
        attn = ops.softmax(attn, axis=-1)
        return ops.matmul(attn, v)

# 5. UNet 去噪网络
class UNetBlock(nn.Cell):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, pad_mode='same', padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, pad_mode='same', padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def construct(self, x):
        residual = x
        x = ops.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return x + residual  # 残差连接

class UNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.time_emb = TimeEmbedding(config["latent_dim"])
        self.down1 = UNetBlock(3, config["channels"][0])
        self.down2 = UNetBlock(config["channels"][0], config["channels"][1])
        self.down3 = UNetBlock(config["channels"][1], config["channels"][2])
        self.up2 = UNetBlock(config["channels"][1] + config["channels"][2], config["channels"][1])
        self.up1 = UNetBlock(config["channels"][0] + config["channels"][1], config["channels"][0])
        self.final = nn.Conv2d(config["channels"][0], 3, 3, pad_mode='same', padding=1)

    def construct(self, x, t, condition):
        # 时间条件注入
        t_emb = self.time_emb(t)
        t_emb = ops.broadcast_to(t_emb[:, None, None, :], (x.shape[0], x.shape[1], x.shape[2], t_emb.shape[1]))
        x = ops.concat((x, t_emb), axis=-1)

        # 下采样路径
        x1 = self.down1(x)
        x2 = self.down2(ops.max_pool2d(x1, 2))
        x3 = self.down3(ops.max_pool2d(x2, 2))

        # 上采样路径
        x = self.up2(ops.concat((ops.resize_nearest_neighbor(x3, scale=2), x2), axis=1))
        x = self.up1(ops.concat((ops.resize_nearest_neighbor(x, scale=2), x1), axis=1))

        # 交叉注意力融合条件
        attn = CrossAttention()(x.view(x.shape[0], -1, x.shape[-1]), condition)
        x = x + attn.view(x.shape)

        return self.final(x)

# 6. 完整 LDM 模型
class LDM(nn.Cell):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.unet = UNet()
        self.cross_attn = CrossAttention()

    def construct(self, x, t, condition):
        # 1. 编码到潜在空间
        z = self.encoder(x)

        # 2. 添加噪声(使用线性噪声调度)
        beta = 1e-4 + (0.02 - 1e-4) * (t / config["timesteps"])
        noise = ops.randn(z.shape) * mnp.sqrt(beta)
        z_noisy = z + noise

        # 3. UNet 去噪
        # 调整形状以适应 UNet 输入
        denoised = self.unet(z_noisy.view(-1, 64, 8, 8), t, condition)

        # 4. 解码回图像空间
        return self.decoder(denoised)