[{"data":1,"prerenderedAt":356},["ShallowReactive",2],{"content-query-2xU2rtJKjq":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":350,"_id":351,"_source":352,"_file":353,"_stem":354,"_extension":355},"/technology-blogs/zh/3751","zh",false,"","扩散模型系列——CFG","Classifier-Free Guidance   论文地址：Classifier-Free Diffusion Guidance    代码地址：https://github.com/teapearce/conditional_diffusion_mnist","2025-05-09","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2025/06/06/103c08dd352648d8aef9e5fab3e9dadc.png","technology-blogs","开发者说",{"type":15,"children":16,"toc":340},"root",[17,25,32,47,58,64,69,74,94,100,105,181,187,256,262,267,324,330],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"扩散模型系列cfg",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":29},"h2",{"id":28},"classifier-free-guidance",[30],{"type":24,"value":31},"Classifier-Free Guidance",{"type":18,"tag":33,"props":34,"children":35},"p",{},[36,38],{"type":24,"value":37},"论文地址：",{"type":18,"tag":39,"props":40,"children":44},"a",{"href":41,"rel":42},"https://arxiv.org/pdf/2207.12598",[43],"nofollow",[45],{"type":24,"value":46},"Classifier-Free Diffusion Guidance",{"type":18,"tag":33,"props":48,"children":49},{},[50,52],{"type":24,"value":51},"代码地址：",{"type":18,"tag":39,"props":53,"children":56},{"href":54,"rel":55},"https://github.com/teapearce/conditional_diffusion_mnist",[43],[57],{"type":24,"value":54},{"type":18,"tag":26,"props":59,"children":61},{"id":60},"一概述",[62],{"type":24,"value":63},"一、概述",{"type":18,"tag":33,"props":65,"children":66},{},[67],{"type":24,"value":68},"Classifier-Free Guidance（无分类器引导，简称 CFG） 是谷歌于 2022 年提出的扩散模型优化技术，旨在增强生成样本的质量与条件契合度。该方法通过联合训练无条件与有条件生成模型，避免了传统Classifier Guidance中显式分类器的依赖，无需额外训练噪声图像分类器。推理阶段，通过线性组合两种生成模式的预测结果，实现灵活的条件引导，显著提升文本生成图像、图像修复等任务的性能。值得强调的是，CFG 是一种推理优化技术，不改变模型训练目标。",{"type":18,"tag":33,"props":70,"children":71},{},[72],{"type":24,"value":73},"Classifier-Free Guidance的主要特点包括：",{"type":18,"tag":75,"props":76,"children":77},"ol",{},[78,84,89],{"type":18,"tag":79,"props":80,"children":81},"li",{},[82],{"type":24,"value":83},"无分类器依赖：无需额外训练和维护分类器，减少了模型复杂性和计算资源消耗。",{"type":18,"tag":79,"props":85,"children":86},{},[87],{"type":24,"value":88},"联合训练策略：通过随机丢弃条件信息，使同一网络同时学习无条件与有条件生成；",{"type":18,"tag":79,"props":90,"children":91},{},[92],{"type":24,"value":93},"灵活的引导强度：引入引导强度超参数 w，可灵活调节生成结果对条件的依赖程度。",{"type":18,"tag":26,"props":95,"children":97},{"id":96},"二主要步骤",[98],{"type":24,"value":99},"二、主要步骤",{"type":18,"tag":33,"props":101,"children":102},{},[103],{"type":24,"value":104},"CFG的实现分为训练阶段与推理阶段，通过巧妙的参数共享与结果融合达成高效引导。",{"type":18,"tag":75,"props":106,"children":107},{},[108,153],{"type":18,"tag":79,"props":109,"children":110},{},[111,113,122,126,128,131,133,136,138,141,143,146,148,151],{"type":24,"value":112},"训练阶段",{"type":18,"tag":114,"props":115,"children":116},"ul",{},[117],{"type":18,"tag":79,"props":118,"children":119},{},[120],{"type":24,"value":121},"无条件样本生成：以概率 p_\\theta随机置空条件输入 z，训练模型生成无约束样本。",{"type":18,"tag":123,"props":124,"children":125},"br",{},[],{"type":24,"value":127},"\\max_{\\theta} \\mathbb{E}_{x \\sim p_\\theta(z)} \\left[ \\log p_\\theta(z) \\right]",{"type":18,"tag":123,"props":129,"children":130},{},[],{"type":24,"value":132},"其中，p_\\theta(z)是无条件生成分布。- 有条件样本生成：同时使用完整条件 z训练模型生成符合条件的样本。",{"type":18,"tag":123,"props":134,"children":135},{},[],{"type":24,"value":137},"\\max_{\\theta} \\mathbb{E}_{x \\sim p_\\theta(z)} \\left[ \\log p_\\theta(z|c) - \\log p_\\theta(z) \\right]",{"type":18,"tag":123,"props":139,"children":140},{},[],{"type":24,"value":142},"其中，p_\\theta(z|c)是有条件生成分布，p_\\theta(z)是无条件生成分布。- 共享参数：无条件与有条件模型共享大部分参数，仅通过条件标记区分输入。",{"type":18,"tag":123,"props":144,"children":145},{},[],{"type":24,"value":147},"\\max_{\\theta} \\mathbb{E}_{x \\sim p_\\theta(z)} \\left[ \\log p_\\theta(z|c) - \\log p_\\theta(z) \\right] + \\lambda \\mathbb{E}_{x \\sim p_\\theta(z)} \\left[ \\log p_\\theta(z) \\right]",{"type":18,"tag":123,"props":149,"children":150},{},[],{"type":24,"value":152},"其中，\\lambda是权重参数，用于控制条件匹配度和无条件生成分布的熵之间的平衡。",{"type":18,"tag":79,"props":154,"children":155},{},[156,158],{"type":24,"value":157},"推理阶段",{"type":18,"tag":114,"props":159,"children":160},{},[161,166,176],{"type":18,"tag":79,"props":162,"children":163},{},[164],{"type":24,"value":165},"并行生成：同时生成无条件预测 p_\\theta(z)和有条件预测 p_\\theta(z|c)。",{"type":18,"tag":79,"props":167,"children":168},{},[169,171,174],{"type":24,"value":170},"线性插值：通过超参数 w对两者进行线性组合：",{"type":18,"tag":123,"props":172,"children":173},{},[],{"type":24,"value":175},"\\text{最终输出} = p_\\theta(z) + w \\cdot \\left[ p_\\theta(z|c) - p_\\theta(z) \\right]",{"type":18,"tag":79,"props":177,"children":178},{},[179],{"type":24,"value":180},"动态调整：w控制条件依赖强度，w=0时为纯无条件生成，w \\to \\infty时强制匹配条件。",{"type":18,"tag":26,"props":182,"children":184},{"id":183},"三关键点",[185],{"type":24,"value":186},"三、关键点",{"type":18,"tag":75,"props":188,"children":189},{},[190,203,238],{"type":18,"tag":79,"props":191,"children":192},{},[193,195],{"type":24,"value":194},"无需分类器设计",{"type":18,"tag":114,"props":196,"children":197},{},[198],{"type":18,"tag":79,"props":199,"children":200},{},[201],{"type":24,"value":202},"传统方法需额外训练分类器 p(c|z)，而CFG通过联合训练无条件生成模型和有条件生成模型，实现了条件信息的隐式建模。",{"type":18,"tag":79,"props":204,"children":205},{},[206,208],{"type":24,"value":207},"梯度组合策略",{"type":18,"tag":114,"props":209,"children":210},{},[211],{"type":18,"tag":79,"props":212,"children":213},{},[214,216,219,220],{"type":24,"value":215},"提出了基于引导强度参数w的梯度组合公式，实现了无条件生成梯度与有条件生成梯度的动态融合，本质上是对预测结果的加权差值放大。公式为：",{"type":18,"tag":123,"props":217,"children":218},{},[],{"type":24,"value":175},{"type":18,"tag":114,"props":221,"children":222},{},[223,228,233],{"type":18,"tag":79,"props":224,"children":225},{},[226],{"type":24,"value":227},"w=1：标准条件生成（如文本对齐图像）。",{"type":18,"tag":79,"props":229,"children":230},{},[231],{"type":24,"value":232},"w>1：增强条件契合度（如更鲜艳的颜色）。",{"type":18,"tag":79,"props":234,"children":235},{},[236],{"type":24,"value":237},"w\u003C1：提升多样性（如抽象艺术风格）。",{"type":18,"tag":79,"props":239,"children":240},{},[241,243],{"type":24,"value":242},"联合训练范式",{"type":18,"tag":114,"props":244,"children":245},{},[246,251],{"type":18,"tag":79,"props":247,"children":248},{},[249],{"type":24,"value":250},"随机条件丢弃：训练时以50%概率随机屏蔽条件输入，强制模型学习数据分布的共性与条件特化；",{"type":18,"tag":79,"props":252,"children":253},{},[254],{"type":24,"value":255},"参数共享：无条件与有条件生成共享U-Net网络参数，仅通过条件标记区分输入，避免冗余训练。",{"type":18,"tag":26,"props":257,"children":259},{"id":258},"四模型结构",[260],{"type":24,"value":261},"四、模型结构",{"type":18,"tag":33,"props":263,"children":264},{},[265],{"type":24,"value":266},"Classifier-Free Guidance通常基于DDPM、DDIM等扩散模型，核心设计围绕共享架构与双路径训练展开：",{"type":18,"tag":75,"props":268,"children":269},{},[270,288,306],{"type":18,"tag":79,"props":271,"children":272},{},[273,275],{"type":24,"value":274},"共享网络架构",{"type":18,"tag":114,"props":276,"children":277},{},[278,283],{"type":18,"tag":79,"props":279,"children":280},{},[281],{"type":24,"value":282},"模型共享UNet结构，通过条件标记区分输入，负责预测噪声或数据分布梯度。",{"type":18,"tag":79,"props":284,"children":285},{},[286],{"type":24,"value":287},"使用Transformer或Embedding层将条件信息（如文本、标签）编码为向量，与时间步嵌入融合后输入网络。",{"type":18,"tag":79,"props":289,"children":290},{},[291,293],{"type":24,"value":292},"双路径训练机制",{"type":18,"tag":114,"props":294,"children":295},{},[296,301],{"type":18,"tag":79,"props":297,"children":298},{},[299],{"type":24,"value":300},"在条件训练路径中，输入条件编码与噪声图像，再训练模型预测条件去噪目标。",{"type":18,"tag":79,"props":302,"children":303},{},[304],{"type":24,"value":305},"在无条件训练路径中，会以一定概率丢弃条件信息，训练模型预测无条件去噪目标。此设计使模型同时学习条件依赖与无条件生成能力。",{"type":18,"tag":79,"props":307,"children":308},{},[309,311],{"type":24,"value":310},"线性插值策略",{"type":18,"tag":114,"props":312,"children":313},{},[314,319],{"type":18,"tag":79,"props":315,"children":316},{},[317],{"type":24,"value":318},"在推理阶段，通过超参数w对无条件预测与有条件预测进行线性组合，实现逼真性与多样性的权衡引导。",{"type":18,"tag":79,"props":320,"children":321},{},[322],{"type":24,"value":323},"当w=0时，仅生成无条件预测；当w=1时，仅生成有条件预测；当w>1时，增强条件匹配度；当w\u003C1时，提升生成多样性。",{"type":18,"tag":26,"props":325,"children":327},{"id":326},"五代码实现",[328],{"type":24,"value":329},"五、代码实现",{"type":18,"tag":331,"props":332,"children":334},"pre",{"code":333},"import mindspore as ms\nfrom mindspore import nn, ops, Tensor\nimport numpy as np\n\n# 设置随机种子确保可重复性\nms.set_seed(42)\n\n# 定义U-Net扩散模型\nclass UNet(nn.Cell):\n    def __init__(self, in_channels=3, out_channels=3, channels=128, time_dim=256):\n        super().__init__()\n        # 时间步嵌入\n        self.time_mlp = nn.SequentialCell(\n            nn.Dense(time_dim, time_dim * 2),\n            nn.SiLU(),\n            nn.Dense(time_dim * 2, time_dim)\n        )\n    \n        # 编码器\n        self.enc1 = nn.SequentialCell(\n            nn.Conv2d(in_channels + time_dim, channels, 3, padding=1),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.enc2 = nn.SequentialCell(\n            nn.Conv2d(channels, channels * 2, 3, padding=1, stride=2),\n            nn.GroupNorm(8, channels * 2),\n            nn.SiLU()\n        )\n    \n        # 中间层\n        self.mid = nn.SequentialCell(\n            nn.Conv2d(channels * 2, channels * 2, 3, padding=1),\n            nn.GroupNorm(8, channels * 2),\n            nn.SiLU()\n        )\n    \n        # 解码器\n        self.dec1 = nn.SequentialCell(\n            nn.Conv2dTranspose(channels * 2, channels, 3, padding=1, stride=2),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.dec2 = nn.SequentialCell(\n            nn.Conv2d(channels + time_dim, channels, 3, padding=1),\n            nn.GroupNorm(4, channels),\n            nn.SiLU()\n        )\n        self.final_conv = nn.Conv2d(channels, out_channels, 3, padding=1)\n\n    def construct(self, x, t):\n        # 时间步嵌入\n        t_emb = self.time_mlp(t)\n        t_emb = ops.tile(t_emb.unsqueeze(1).unsqueeze(2), (1, x.shape[1], x.shape[2], 1)).transpose(0, 3, 1, 2)\n    \n        # 编码\n        x = ops.concat([x, t_emb], axis=1)\n        x = self.enc1(x)\n        x = self.enc2(x)\n    \n        # 中间层\n        x = self.mid(x)\n    \n        # 解码\n        x = self.dec1(x)\n        x = ops.concat([x, t_emb], axis=1)\n        x = self.dec2(x)\n    \n        # 输出噪声预测\n        return self.final_conv(x)\n\n# 定义Classifier-Free Guidance训练器\nclass CFGDiffusion(nn.Cell):\n    def __init__(self, model, guidance_scale=5.0):\n        super().__init__()\n        self.model = model\n        self.guidance_scale = guidance_scale\n        self.loss_fn = nn.MSELoss()\n        self.beta_schedule = self._linear_beta_schedule(timesteps=1000)\n        self.alpha_cumprod = np.cumprod(1 - self.beta_schedule)\n\n    def _linear_beta_schedule(self, timesteps):\n        beta_start = 1e-4\n        beta_end = 0.02\n        return np.linspace(beta_start, beta_end, timesteps)\n\n    def construct(self, x0, t, noise=None, y=None):\n        # 添加噪声\n        if noise is None:\n            noise = ops.randn_like(x0)\n    \n        # 前向过程\n        sqrt_alpha_cumprod = ops.sqrt(Tensor(self.alpha_cumprod[t], ms.float32))\n        sqrt_one_minus_alpha_cumprod = ops.sqrt(1 - Tensor(self.alpha_cumprod[t], ms.float32))\n        xt = sqrt_alpha_cumprod * x0 + sqrt_one_minus_alpha_cumprod * noise\n    \n        # 条件丢弃（50%概率）\n        if y is not None and ops.rand(1) > 0.5:\n            y = None\n    \n        # 预测噪声\n        if y is None:\n            pred_noise = self.model(xt, t)\n        else:\n            pred_noise_cond = self.model(xt, t)\n            pred_noise_uncond = self.model(xt, t)  # 无条件路径复用模型\n            pred_noise = pred_noise_uncond + self.guidance_scale * (pred_noise_cond - pred_noise_uncond)\n    \n        # 计算损失\n        loss = self.loss_fn(pred_noise, noise)\n        return loss\n",[335],{"type":18,"tag":336,"props":337,"children":338},"code",{"__ignoreMap":7},[339],{"type":24,"value":333},{"title":7,"searchDepth":341,"depth":341,"links":342},4,[343,345,346,347,348,349],{"id":28,"depth":344,"text":31},2,{"id":60,"depth":344,"text":63},{"id":96,"depth":344,"text":99},{"id":183,"depth":344,"text":186},{"id":258,"depth":344,"text":261},{"id":326,"depth":344,"text":329},"markdown","content:technology-blogs:zh:3751.md","content","technology-blogs/zh/3751.md","technology-blogs/zh/3751","md",1776506134497]