[{"data":1,"prerenderedAt":322},["ShallowReactive",2],{"content-query-Nu5YSOBK8H":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":316,"_id":317,"_source":318,"_file":319,"_stem":320,"_extension":321},"/technology-blogs/zh/3764","zh",false,"","开发者说 | 基于昇思MindSpore实现CFG扩散模型","昇思MindSpore2024年技术帖分享大会圆满结束！","2025-06-17","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2025/06/20/0ecdd5c480e04efb84c72d89868ca625.png","technology-blogs","开发者说",{"type":15,"children":16,"toc":311},"root",[17,25,31,36,41,50,58,63,73,78,83,88,96,104,109,117,125,132,152,160,165,170,178,186,209,217,222,227,235,240,245,253,258,263,271,279,289,298],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"开发者说-基于昇思mindspore实现cfg扩散模型",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":28},"p",{},[29],{"type":24,"value":30},"作者：Adream",{"type":18,"tag":26,"props":32,"children":33},{},[34],{"type":24,"value":35},"来源：昇思论坛",{"type":18,"tag":26,"props":37,"children":38},{},[39],{"type":24,"value":40},"昇思MindSpore2024年技术帖分享大会圆满结束！全年收获80+高质量技术帖， 2025年全新升级，推出“2025年昇思干货小卖部，你投我就收！”，活动继续每月征集技术帖。本期技术文章由社区开发者Adrem输出并投稿。如果您对活动感兴趣，欢迎在昇思论坛投稿。",{"type":18,"tag":26,"props":42,"children":43},{},[44],{"type":18,"tag":45,"props":46,"children":47},"strong",{},[48],{"type":24,"value":49},"# 01",{"type":18,"tag":26,"props":51,"children":52},{},[53],{"type":18,"tag":45,"props":54,"children":55},{},[56],{"type":24,"value":57},"概述",{"type":18,"tag":26,"props":59,"children":60},{},[61],{"type":24,"value":62},"Classifier-Free Guidance（无分类器引导，简称 CFG） 是谷歌于 2022 年提出的扩散模型优化技术，旨在增强生成样本的质量与条件契合度。该方法通过联合训练无条件与有条件生成模型，避免了传统Classifier Guidance中显式分类器的依赖，无需额外训练噪声图像分类器。推理阶段，通过线性组合两种生成模式的预测结果，实现灵活的条件引导，显著提升文本生成图像、图像修复等任务的性能。值得强调的是，CFG 
是一种推理优化技术，不改变模型训练目标。本文将详细讲述CFG的原理及如何用昇思MindSpore实现它，并附上代码。",{"type":18,"tag":26,"props":64,"children":65},{},[66,71],{"type":18,"tag":45,"props":67,"children":68},{},[69],{"type":24,"value":70},"Classifier-Free Guidance",{"type":24,"value":72}," 的主要特点包括：",{"type":18,"tag":26,"props":74,"children":75},{},[76],{"type":24,"value":77},"1、**无分类器依赖：**无需额外训练和维护分类器，减少了模型复杂性和计算资源消耗。",{"type":18,"tag":26,"props":79,"children":80},{},[81],{"type":24,"value":82},"2、**联合训练策略：**通过随机丢弃条件信息，使同一网络同时学习无条件与有条件生成；",{"type":18,"tag":26,"props":84,"children":85},{},[86],{"type":24,"value":87},"3、**灵活的引导强度：**引入引导强度超参数 $w$，可灵活调节生成结果对条件的依赖程度。",{"type":18,"tag":26,"props":89,"children":90},{},[91],{"type":18,"tag":45,"props":92,"children":93},{},[94],{"type":24,"value":95},"# 02",{"type":18,"tag":26,"props":97,"children":98},{},[99],{"type":18,"tag":45,"props":100,"children":101},{},[102],{"type":24,"value":103},"主要步骤",{"type":18,"tag":26,"props":105,"children":106},{},[107],{"type":24,"value":108},"CFG的实现分为训练阶段与推理阶段，通过巧妙的参数共享与结果融合达成高效引导。",{"type":18,"tag":26,"props":110,"children":111},{},[112],{"type":18,"tag":45,"props":113,"children":114},{},[115],{"type":24,"value":116},"1、训练阶段",{"type":18,"tag":26,"props":118,"children":119},{},[120],{"type":18,"tag":121,"props":122,"children":124},"img",{"alt":7,"src":123},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2025/06/20/4a125295d2f04c6eae183fe575be30fb.png",[],{"type":18,"tag":26,"props":126,"children":127},{},[128],{"type":18,"tag":121,"props":129,"children":131},{"alt":7,"src":130},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2025/06/20/dc5a0e5e7479422ab4b8dfdd3c36005d.png",[],{"type":18,"tag":133,"props":134,"children":135},"ul",{},[136,142,147],{"type":18,"tag":137,"props":138,"children":139},"li",{},[140],{"type":24,"value":141},"w=1：标准条件生成（如文本对齐图像）。",{"type":18,"tag":137,"props":143,"children":144},{},[145],{"type":24,"value":146},"w>1：增强条件契合度（如更鲜艳的颜色）。",{"type":18,"tag":137,"props":148,"children
":149},{},[150],{"type":24,"value":151},"w\u003C1：提升多样性（如抽象艺术风格）。",{"type":18,"tag":26,"props":153,"children":154},{},[155],{"type":18,"tag":45,"props":156,"children":157},{},[158],{"type":24,"value":159},"3、联合训练范式",{"type":18,"tag":26,"props":161,"children":162},{},[163],{"type":24,"value":164},"**随机条件丢弃：**训练时以50%概率随机屏蔽条件输入，强制模型学习数据分布的共性与条件特化；",{"type":18,"tag":26,"props":166,"children":167},{},[168],{"type":24,"value":169},"**参数共享：**无条件与有条件生成共享U-Net网络参数，仅通过条件标记区分输入，避免冗余训练。",{"type":18,"tag":26,"props":171,"children":172},{},[173],{"type":18,"tag":45,"props":174,"children":175},{},[176],{"type":24,"value":177},"# 04",{"type":18,"tag":26,"props":179,"children":180},{},[181],{"type":18,"tag":45,"props":182,"children":183},{},[184],{"type":24,"value":185},"模型结构",{"type":18,"tag":26,"props":187,"children":188},{},[189,193,195,200,202,207],{"type":18,"tag":45,"props":190,"children":191},{},[192],{"type":24,"value":70},{"type":24,"value":194}," 通常基于DDPM、DDIM等扩散模型，核心设计围绕",{"type":18,"tag":45,"props":196,"children":197},{},[198],{"type":24,"value":199},"共享架构",{"type":24,"value":201},"与",{"type":18,"tag":45,"props":203,"children":204},{},[205],{"type":24,"value":206},"双路径训练",{"type":24,"value":208},"展开：",{"type":18,"tag":26,"props":210,"children":211},{},[212],{"type":18,"tag":45,"props":213,"children":214},{},[215],{"type":24,"value":216},"1、共享网络架构",{"type":18,"tag":26,"props":218,"children":219},{},[220],{"type":24,"value":221},"模型共享UNet结构，通过条件标记区分输入，负责预测噪声或数据分布梯度。",{"type":18,"tag":26,"props":223,"children":224},{},[225],{"type":24,"value":226},"使用Transformer或Embedding层将条件信息（如文本、标签）编码为向量，与时间步嵌入融合后输入网络。",{"type":18,"tag":26,"props":228,"children":229},{},[230],{"type":18,"tag":45,"props":231,"children":232},{},[233],{"type":24,"value":234},"2、双路径训练机制",{"type":18,"tag":26,"props":236,"children":237},{},[238],{"type":24,"value":239},"在条件训练路径中，输入条件编码与噪声图像，再训练模型预测条件去噪目标。",{"type":18,"tag":26,"props":241,"children":242},{},[243],{"type":24,"value":244},"在无条件训练路径中，会以一定概率丢弃条件信息，
训练模型预测无条件去噪目标。此设计使模型同时学习条件依赖与无条件生成能力。",{"type":18,"tag":26,"props":246,"children":247},{},[248],{"type":18,"tag":45,"props":249,"children":250},{},[251],{"type":24,"value":252},"3、线性插值策略",{"type":18,"tag":26,"props":254,"children":255},{},[256],{"type":24,"value":257},"在推理阶段，通过超参数w对无条件预测与有条件预测进行线性组合，实现逼真性与多样性的权衡引导。",{"type":18,"tag":26,"props":259,"children":260},{},[261],{"type":24,"value":262},"当w=0时，仅生成无条件预测；当w=1时，仅生成有条件预测；当w>1时，增强条件匹配度；当w\u003C1时，提升生成多样性。",{"type":18,"tag":26,"props":264,"children":265},{},[266],{"type":18,"tag":45,"props":267,"children":268},{},[269],{"type":24,"value":270},"# 05",{"type":18,"tag":26,"props":272,"children":273},{},[274],{"type":18,"tag":45,"props":275,"children":276},{},[277],{"type":24,"value":278},"MindSpore代码实现",{"type":18,"tag":280,"props":281,"children":283},"pre",{"code":282},"\nimport mindspore as ms\nfrom mindspore import nn, ops, Tensor\nimport numpy as np\n\n# 设置随机种子确保可重复性\nms.set_seed(42)\n\n# 定义U-Net扩散模型\nclass UNet(nn.Cell):\n    def __init__(self, in_channels=3, out_channels=3, channels=128, time_dim=256):\n        super().__init__()\n        # 时间步嵌入\n        self.time_mlp = nn.SequentialCell(\n            nn.Dense(time_dim, time_dim * 2),\n            nn.SiLU(),\n            nn.Dense(time_dim * 2, time_dim)\n        )\n       \n        # 编码器\n        self.enc1 = nn.SequentialCell(\n            nn.Conv2d(in_channels + time_dim, channels, 3, padding=1),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.enc2 = nn.SequentialCell(\n            nn.Conv2d(channels, channels * 2, 3, padding=1, stride=2),\n            nn.GroupNorm(8, channels * 2),\n            nn.SiLU()\n        )\n       \n        # 中间层\n        self.mid = nn.SequentialCell(\n            nn.Conv2d(channels * 2, channels * 2, 3, padding=1),\n            nn.GroupNorm(8, channels * 2),\n            nn.SiLU()\n        )\n        \n        # 解码器\n        self.dec1 = nn.SequentialCell(\n            
nn.Conv2dTranspose(channels * 2, channels, 3, padding=1, stride=2),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.dec2 = nn.SequentialCell(\n            nn.Conv2d(channels + time_dim, channels, 3, padding=1),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.final_conv = nn.Conv2d(channels, out_channels, 3, padding=1)\n    \n    def construct(self, x, t):\n        # 时间步嵌入\n        t_emb = self.time_mlp(t)\n        t_emb = ops.tile(t_emb.unsqueeze(1).unsqueeze(2), (1, x.shape[1], x.shape[2], 1)).transpose(0, 3, 1, 2)\n       \n        # 编码\n        x = ops.concat([x, t_emb], axis=1)\n        x = self.enc1(x)\n        x = self.enc2(x)\n        \n        # 中间层\n        x = self.mid(x)\n        \n        # 解码\n        x = self.dec1(x)\n        x = ops.concat([x, t_emb], axis=1)\n        x = self.dec2(x)\n        \n        # 输出噪声预测\n        return self.final_conv(x)\n\n# 定义Classifier-Free Guidance训练器\nclass CFGDiffusion(nn.Cell):\n    def __init__(self, model, guidance_scale=5.0):\n        super().__init__()\n        self.model = model\n        self.guidance_scale = guidance_scale\n        self.loss_fn = nn.MSELoss()\n        self.beta_schedule = self._linear_beta_schedule(timesteps=1000)\n        self.alpha_cumprod = np.cumprod(1 - self.beta_schedule)\n    \n    def _linear_beta_schedule(self, timesteps):\n        beta_start = 1e-4\n        beta_end = 0.02\n        return np.linspace(beta_start, beta_end, timesteps)\n   \n    def construct(self, x0, t, noise=None, y=None):\n        # 添加噪声\n        if noise is None:\n            noise = ops.randn_like(x0)\n        \n        # 前向过程\n        sqrt_alpha_cumprod = ops.sqrt(Tensor(self.alpha_cumprod[t], ms.float32))\n        sqrt_one_minus_alpha_cumprod = ops.sqrt(1 - Tensor(self.alpha_cumprod[t], ms.float32))\n        xt = sqrt_alpha_cumprod * x0 + sqrt_one_minus_alpha_cumprod * noise\n        \n        # 条件丢弃（50%概率）\n        if y is not None 
and ops.rand(1) > 0.5:\n            y = None\n       \n        # 预测噪声\n        if y is None:\n            pred_noise = self.model(xt, t)\n        else:\n            pred_noise_cond = self.model(xt, t)\n            pred_noise_uncond = self.model(xt, t)  # 无条件路径复用模型\n            pred_noise = pred_noise_uncond + self.guidance_scale * (pred_noise_cond - pred_noise_uncond)\n        \n        # 计算损失\n        loss = self.loss_fn(pred_noise, noise)\n        return loss\n",[284],{"type":18,"tag":285,"props":286,"children":287},"code",{"__ignoreMap":7},[288],{"type":24,"value":282},{"type":18,"tag":290,"props":291,"children":293},"h3",{"id":292},"参考链接",[294],{"type":18,"tag":45,"props":295,"children":296},{},[297],{"type":24,"value":292},{"type":18,"tag":26,"props":299,"children":300},{},[301,303],{"type":24,"value":302},"[1] 论文地址：",{"type":18,"tag":304,"props":305,"children":309},"a",{"href":306,"rel":307},"https://arxiv.org/pdf/2207.12598",[308],"nofollow",[310],{"type":24,"value":306},{"title":7,"searchDepth":312,"depth":312,"links":313},4,[314],{"id":292,"depth":315,"text":292},3,"markdown","content:technology-blogs:zh:3764.md","content","technology-blogs/zh/3764.md","technology-blogs/zh/3764","md",1776506134754]