mindscience.diffuser.DiffusionTrainer

class mindscience.diffuser.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1.0, loss_type='l1')[source]

Diffusion Trainer base class.

Parameters
  • model (nn.Cell) – The diffusion backbone network used to make predictions during training.

  • scheduler (DiffusionScheduler) – The diffusion scheduler that controls the forward noising process.

  • objective (str) – The prediction objective of the model. Default: 'pred_noise'.

  • p2_loss_weight_gamma (float) – The gamma of the p2 loss weighting; 0 means no p2 reweighting. Default: 0..

  • p2_loss_weight_k (float) – The k of the p2 loss weighting. Default: 1.0.

  • loss_type (str) – The type of training loss, 'l1' or 'l2'. Default: 'l1'.

Raises

TypeError – If scheduler is not DiffusionScheduler type.
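The p2 weighting controlled by p2_loss_weight_gamma and p2_loss_weight_k follows "Perception Prioritized Training of Diffusion Models", which weights each timestep's loss by (k + SNR)^(-gamma). A minimal NumPy sketch (the linear beta schedule below is an illustrative assumption; in practice the scheduler supplies alphas_cumprod):

```python
import numpy as np

# Illustrative linear beta schedule (assumption; the trainer's scheduler
# provides the real cumulative alpha products).
num_train_timesteps = 100
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

def p2_loss_weight(gamma, k):
    # Per-timestep weight (k + SNR)^(-gamma), with SNR = a_bar / (1 - a_bar).
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    return (k + snr) ** (-gamma)

# gamma = 0 (the default) disables reweighting: every timestep gets weight 1.
uniform = p2_loss_weight(0.0, 1.0)
# gamma > 0 down-weights early timesteps, which have high SNR (little noise).
w = p2_loss_weight(1.0, 1.0)
```

With gamma > 0, `w[0] < w[-1]`: the nearly-clean timesteps contribute less to the loss, prioritizing the perceptually harder, noisier ones.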

Examples

>>> from mindspore import ops, dtype as mstype
>>> from mindscience.diffuser import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> hidden_dim, layers, heads = 256, 3, 4
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, [batch_size, 1])
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)

get_loss(original_samples, noise, timesteps, condition=None)[source]

Calculate the forward loss of diffusion process.

Parameters
  • original_samples (Tensor) – The original clean samples to which noise is added during the diffusion forward process.

  • noise (Tensor) – A current instance of a noise sample created by the diffusion process.

  • timesteps (Tensor) – The current discrete timestep in the diffusion chain.

  • condition (Tensor, optional) – The condition for desired outputs. Default: None.

Returns

Tensor, the model forward loss.
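For intuition, the forward loss under objective='pred_noise' and loss_type='l2' can be sketched in NumPy. Everything below is an illustrative assumption, not this API: the schedule is a stand-in for the scheduler's, and a perfect model is simulated by passing the true noise as the prediction.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative schedule (assumption; the real one comes from the scheduler).
betas = np.linspace(1e-4, 0.02, 100)
alphas_cumprod = np.cumprod(1.0 - betas)

def get_loss_sketch(original_samples, noise, timesteps, model_pred):
    # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise.
    a_bar = alphas_cumprod[timesteps][:, None, None]
    noisy = np.sqrt(a_bar) * original_samples + np.sqrt(1.0 - a_bar) * noise
    # With objective='pred_noise' the regression target is the added noise;
    # with loss_type='l2' the loss is the mean squared error against it.
    return noisy, np.mean((model_pred - noise) ** 2)

x0 = rng.standard_normal((4, 256, 16))
noise = rng.standard_normal((4, 256, 16))
t = rng.integers(0, 100, size=4)
# A perfect model that returns the true noise yields zero loss.
noisy, loss = get_loss_sketch(x0, noise, t, model_pred=noise)
```

In the real trainer, `model_pred` is produced by the backbone network from the noisy samples, the timesteps, and the optional condition.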