开发者说 | 基于昇思MindSpore实现LDM扩散模型
开发者说 | 基于昇思MindSpore实现LDM扩散模型
作者:Adream
来源:昇思论坛
昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者Adrem输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。
# 01
概述
LDM(Latent Diffusion Model)是一种基于扩散过程的先进生成模型,其核心思想是通过逐步去除图像中的噪声来生成高质量的图像样本。与传统扩散模型不同,LDM 引入了潜在空间的概念,将图像表示为潜在空间中的向量。通过学习潜在空间中的映射关系,结合预训练好的变分自编码器(VAE),实现高效的图像生成。由于 LDM 在潜在空间进行操作,特征维度远小于图像空间,因此推理速度相较于基于 DDPM 的模型有显著提升。
LDM 的主要特点包括:
**1、VAE 潜在表示学习:**利用 VAE 学习图像的潜在表示,在压缩数据的同时保留关键特征,为后续的扩散过程提供低维高效的输入。
**2、UNet 映射学习:**采用 UNet 架构学习从潜在空间到图像空间的映射关系,充分利用 UNet 的多尺度特征提取和重建能力。
**3、条件扩散生成:**支持条件扩散过程,能够根据给定的条件(如文本、图像等)生成特定类型的图像,增强了模型的可控性。
**4、分层扩散机制:**运用分层扩散过程,在不同的分辨率层级上进行图像生成,有助于生成高分辨率、细节丰富的图像。
# 02
主要步骤
**1、**数据编码
使用预训练的 VAE 编码器将原始高维图像数据映射到低维潜在空间。这一过程不仅降低了计算复杂度,还提取了图像的关键特征,为后续的扩散过程提供了更紧凑的表示。
**2、**潜在空间扩散
在潜在空间中执行与 DDPM 类似的扩散过程。通过逐步向潜在表示添加高斯噪声,将其从原始数据分布转化为纯噪声分布。在反向过程中,通过 UNet 逐步去除噪声,生成新的潜在样本。
3、数据重建
使用预训练好的 VAE 解码器将去噪后的潜在表示映射回原始图像空间,得到最终的生成图像。
# 03
关键点
**1、**感知压缩与潜在空间优化
LDM 通过预训练的自动编码器将高维图像数据进行感知压缩,映射到低维潜在空间。例如,将 256×256 的图像压缩为 16×16 的潜在表示,大幅减少了计算量。为了控制潜在空间的分布,LDM 采用了以下两种正则化技术:
- **KL 正则化:**约束潜在变量接近标准正态分布,类似于 VAE 的做法,增强了生成过程的稳定性。
- **VQ 正则化:**通过向量量化层对潜在表示进行离散化,类似于 VQ - VAE,有助于提升模型对结构化特征的学习能力。
**2、**扩散过程与去噪机制
LDM 在潜在空间中进行扩散过程,分为前向扩散和反向去噪两个阶段:
- **前向扩散:**逐步向潜在变量添加高斯噪声,噪声强度随时间步递增(如采用线性或余弦调度),最终使潜在表示接近纯噪声分布。
- **反向去噪:**使用 UNet 预测当前潜在表示中的噪声,并逐步去除噪声。目标函数为预测噪声与实际噪声的均方误差,通过优化该目标函数提升模型的去噪能力。
**3、**多模态条件生成
LDM 引入了交叉注意力机制,支持文本、图像、语义地图等多模态条件输入。具体实现方式如下:
- **条件编码器:**将不同模态的输入(如文本、图像)映射为中间表示,再与 UNet 的中间层进行交互。
- **注意力融合:**通过查询(Q)、键(K)、值(V)机制,将条件信息融入生成过程,实现对图像生成的精准控制,例如根据文本描述生成相应的图像。
# 04
模型结构
**1、**像素空间与潜在空间的转换
- **编码器(Encoder):**由卷积网络构成,通过下采样和特征提取操作,将输入图像 x从高维像素空间压缩到低维潜在空间,生成潜在表示 z。

# 05
MindSpo****re代码实现
import mindspore as ms
from mindspore import nn, ops, Tensor, Parameter
import mindspore.numpy as mnp
import numpy as np
# 超参数配置
config = {
"latent_dim": 64, # 潜在空间维度
"image_size": 64, # 输入图像尺寸
"batch_size": 32, # 批次大小
"timesteps": 1000, # 扩散时间步数
"lr": 1e-4, # 学习率
"channels": [64, 128, 256],# UNet通道数
"condition_dim": 128 # 条件嵌入维度
}
# 2. 编码器 - 解码器模块
class Encoder(nn.Cell):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, pad_mode='same', padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='same', padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.pool = nn.MaxPool2d(2, 2)
# 对于 64x64 输入,特征图尺寸为 64x16x16
self.fc = nn.Dense(64 * 16 * 16, config["latent_dim"])
def construct(self, x):
x = ops.relu(self.bn1(self.conv1(x)))
x = self.pool(ops.relu(self.bn2(self.conv2(x))))
x = x.view(x.shape[0], -1)
return self.fc(x)
class Decoder(nn.Cell):
def __init__(self):
super().__init__()
self.fc = nn.Dense(config["latent_dim"], 64 * 16 * 16)
self.conv1 = nn.Conv2d(64, 32, 3, pad_mode='same', padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 3, 3, pad_mode='same', padding=1)
self.upsample = nn.ResizeNearestNeighbor((32, 32))
def construct(self, z):
x = self.fc(z).view(-1, 64, 16, 16)
x = self.upsample(ops.relu(self.bn1(self.conv1(x))))
return self.conv2(x)
# 3. 时间编码模块
class TimeEmbedding(nn.Cell):
def __init__(self, dim):
super().__init__()
self.embed = nn.Embedding(config["timesteps"], dim)
self.proj = nn.Dense(dim, dim)
def construct(self, t):
return self.proj(self.embed(t))
# 4. 交叉注意力模块
class CrossAttention(nn.Cell):
def __init__(self):
super().__init__()
self.q_proj = nn.Dense(config["latent_dim"], config["condition_dim"])
self.k_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
self.v_proj = nn.Dense(config["condition_dim"], config["condition_dim"])
self.scale = mnp.sqrt(mnp.asarray(config["condition_dim"], dtype=ms.float32))
def construct(self, z, condition):
q = self.q_proj(z)
k = self.k_proj(condition)
v = self.v_proj(condition)
attn = ops.matmul(q, k.transpose(0, 1)) / self.scale
attn = ops.softmax(attn, axis=-1)
return ops.matmul(attn, v)
# 5. UNet 去噪网络
class UNetBlock(nn.Cell):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, pad_mode='same', padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, pad_mode='same', padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def construct(self, x):
residual = x
x = ops.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
return x + residual # 残差连接
class UNet(nn.Cell):
def __init__(self):
super().__init__()
self.time_emb = TimeEmbedding(config["latent_dim"])
self.down1 = UNetBlock(3, config["channels"][0])
self.down2 = UNetBlock(config["channels"][0], config["channels"][1])
self.down3 = UNetBlock(config["channels"][1], config["channels"][2])
self.up2 = UNetBlock(config["channels"][1] + config["channels"][2], config["channels"][1])
self.up1 = UNetBlock(config["channels"][0] + config["channels"][1], config["channels"][0])
self.final = nn.Conv2d(config["channels"][0], 3, 3, pad_mode='same', padding=1)
def construct(self, x, t, condition):
# 时间条件注入
t_emb = self.time_emb(t)
t_emb = ops.broadcast_to(t_emb[:, None, None, :], (x.shape[0], x.shape[1], x.shape[2], t_emb.shape[1]))
x = ops.concat((x, t_emb), axis=-1)
# 下采样路径
x1 = self.down1(x)
x2 = self.down2(ops.max_pool2d(x1, 2))
x3 = self.down3(ops.max_pool2d(x2, 2))
# 上采样路径
x = self.up2(ops.concat((ops.resize_nearest_neighbor(x3, scale=2), x2), axis=1))
x = self.up1(ops.concat((ops.resize_nearest_neighbor(x, scale=2), x1), axis=1))
# 交叉注意力融合条件
attn = CrossAttention()(x.view(x.shape[0], -1, x.shape[-1]), condition)
x = x + attn.view(x.shape)
return self.final(x)
# 6. 完整 LDM 模型
class LDM(nn.Cell):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.unet = UNet()
self.cross_attn = CrossAttention()
def construct(self, x, t, condition):
# 1. 编码到潜在空间
z = self.encoder(x)
# 2. 添加噪声(使用线性噪声调度)
beta = 1e-4 + (0.02 - 1e-4) * (t / config["timesteps"])
noise = ops.randn(z.shape) * mnp.sqrt(beta)
z_noisy = z + noise
# 3. UNet 去噪
# 调整形状以适应 UNet 输入
denoised = self.unet(z_noisy.view(-1, 64, 8, 8), t, condition)
# 4. 解码回图像空间
return self.decoder(denoised)
参考链接
[1] 论文地址:https://arxiv.org/pdf/2112.10752