mindscience.diffuser.DiffusionTrainer
- class mindscience.diffuser.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1.0, loss_type='l1')
Base class for diffusion trainers.
- Parameters:
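model (nn.Cell) - Backbone network of the diffusion model.
scheduler (DiffusionScheduler) - Diffusion scheduler, either a DDPM or a DDIM scheduler.
objective (str, optional) - Prediction target type of the scheduler function. Supported values: "pred_noise" (predict the noise added during the diffusion process), "pred_x0" (predict the original sample), or "pred_v" (see Section 2.4 of the Imagen Video paper). Default: "pred_noise".
p2_loss_weight_gamma (float, optional) - The gamma parameter of the P2 loss weight. See Perception Prioritized Training of Diffusion Models for details. Default: 0.0.
p2_loss_weight_k (float, optional) - The k parameter of the P2 loss weight. See Perception Prioritized Training of Diffusion Models for details. Default: 1.0.
loss_type (str, optional) - Loss function type. Supported values: "l1" or "l2". Default: "l1".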
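The three `objective` values differ only in the regression target the backbone is trained against. The plain NumPy sketch below shows one common way to form those targets and the element-wise l1/l2 loss. It assumes the standard forward process x_t = sqrt(ᾱ_t)·x0 + sqrt(1 − ᾱ_t)·ε and the v-parameterization v = sqrt(ᾱ_t)·ε − sqrt(1 − ᾱ_t)·x0 used in Imagen Video. The helpers `diffusion_target` and `elementwise_loss` are hypothetical illustrations, not part of MindScience, and this is not the library's implementation.

```python
import numpy as np


def diffusion_target(objective, x0, noise, alpha_bar_t):
    """Return the regression target for a given objective (illustrative sketch).

    alpha_bar_t is the cumulative product of alphas at the sampled timestep;
    it may be a scalar or an array that broadcasts against x0 (e.g. shape (B, 1, 1)).
    """
    if objective == "pred_noise":
        # The network predicts the Gaussian noise that was added.
        return noise
    if objective == "pred_x0":
        # The network predicts the clean sample directly.
        return x0
    if objective == "pred_v":
        # v-parameterization: v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x0
        return np.sqrt(alpha_bar_t) * noise - np.sqrt(1.0 - alpha_bar_t) * x0
    raise ValueError(f"unknown objective: {objective!r}")


def elementwise_loss(pred, target, loss_type):
    """Element-wise l1 (absolute) or l2 (squared) error, before any reduction."""
    if loss_type == "l1":
        return np.abs(pred - target)
    if loss_type == "l2":
        return (pred - target) ** 2
    raise ValueError(f"unknown loss_type: {loss_type!r}")
```

Under these assumptions, a model trained with "pred_v" still yields the clean sample, via x0 = sqrt(ᾱ_t)·x_t − sqrt(1 − ᾱ_t)·v.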
- Raises:
TypeError - If scheduler is not of type DiffusionScheduler.
Examples:
>>> from mindspore import ops, dtype as mstype
>>> from mindscience.diffuser import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> # network sizes (illustrative values)
>>> hidden_dim, layers, heads = 256, 3, 4
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, [batch_size, 1])
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)
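The example above passes p2_loss_weight_gamma=0 and p2_loss_weight_k=1. Under the weighting proposed in the P2 paper, commonly written as w_t = (k + SNR_t)^(-γ) with SNR_t = ᾱ_t / (1 − ᾱ_t), γ = 0 makes every weight 1, i.e. an unweighted loss, while γ > 0 down-weights low-noise (high-SNR) timesteps. The NumPy sketch below computes these weights on a cosine schedule and applies them per sample. It assumes "squaredcos_cap_v2" follows the Hugging Face diffusers definition (offset 0.008, betas capped at 0.999) and that DiffusionTrainer uses exactly this P2 formula; both are assumptions, and all function names are hypothetical.

```python
import math

import numpy as np


def squaredcos_cap_v2_alphas_cumprod(num_train_timesteps, max_beta=0.999):
    """Cumulative alphas for a cosine schedule (assumed diffusers-style definition)."""

    def alpha_bar(t):
        # Continuous-time cosine schedule with a small offset s = 0.008.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_train_timesteps):
        t1 = i / num_train_timesteps
        t2 = (i + 1) / num_train_timesteps
        # Discretize the schedule and cap each beta at max_beta.
        betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.cumprod(1.0 - np.array(betas))


def p2_loss_weight(alphas_cumprod, gamma, k):
    """P2 weight w_t = (k + SNR_t) ** (-gamma), with SNR_t = a_t / (1 - a_t)."""
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    return (k + snr) ** (-gamma)


def p2_weighted_mean(per_element_loss, timesteps, weights):
    """Average each sample's loss, scale it by its timestep weight, then average the batch.

    timesteps is assumed to be a 1-D integer array of length batch_size.
    """
    batch = per_element_loss.shape[0]
    per_sample = per_element_loss.reshape(batch, -1).mean(axis=1)
    return float((per_sample * weights[timesteps]).mean())
```

With γ = 1 and k = 1 the weight reduces to 1 − ᾱ_t, so it grows monotonically with the timestep under this schedule.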