# mindspore.nn.probability.dpn.ConditionalVAE

class mindspore.nn.probability.dpn.ConditionalVAE(encoder, decoder, hidden_size, latent_size, num_classes)

Conditional Variational Auto-Encoder (CVAE).

Unlike a standard VAE, a CVAE also conditions on label information. For more details, refer to Learning Structured Output Representation using Deep Conditional Generative Models.
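A CVAE is trained by maximizing a conditional evidence lower bound (ELBO). In this page's notation, where x is the data and y is the conditioning label, the bound reads as follows. Using a standard normal prior in place of p(z | y) is a common simplification; whether this class does so is an assumption here, not something this page states.

$$
\log p_\theta(x \mid y) \ge \mathbb{E}_{q_\phi(z \mid x, y)}\left[\log p_\theta(x \mid z, y)\right] - D_{\mathrm{KL}}\left(q_\phi(z \mid x, y) \,\Vert\, p(z \mid y)\right)
$$

The first term rewards accurate reconstruction of x given the latent code and the label; the second keeps the approximate posterior close to the prior.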

Note

When defining the encoder and decoder, the encoder's output tensor and the decoder's input tensor must both have shape $$(N, hidden\_size)$$. The latent_size must be less than or equal to hidden_size.
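The constraints in this note can be expressed as a small check. The sketch below is a hypothetical helper (`check_cvae_shapes` is not part of MindSpore) that works on plain shape tuples rather than tensors, so it can be run without a MindSpore installation:

```python
def check_cvae_shapes(encoder_out_shape, decoder_in_shape, hidden_size, latent_size):
    """Validate the shape constraints from the ConditionalVAE note.

    Hypothetical helper, not part of MindSpore; shapes are plain tuples.
    """
    # Both tensors must be 2-D with shape (N, hidden_size).
    for name, shape in (("encoder output", encoder_out_shape),
                        ("decoder input", decoder_in_shape)):
        if len(shape) != 2 or shape[1] != hidden_size:
            raise ValueError(f"{name} must have shape (N, {hidden_size}), got {shape}")
    # The latent space may not be wider than the hidden layer.
    if latent_size > hidden_size:
        raise ValueError(f"latent_size ({latent_size}) must be <= hidden_size ({hidden_size})")
    return True


# A batch of 32 with hidden_size=400 and latent_size=20 satisfies every constraint.
print(check_cvae_shapes((32, 400), (32, 400), hidden_size=400, latent_size=20))  # True
```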

Parameters
• encoder (Cell) – The Deep Neural Network (DNN) model defined as encoder.

• decoder (Cell) – The DNN model defined as decoder.

• hidden_size (int) – The size of the encoder’s output tensor.

• latent_size (int) – The size of the latent space.

• num_classes (int) – The number of classes.

Inputs:
• input_x (Tensor) - The input tensor of shape $$(N, C, H, W)$$, the same as the input of the encoder.

• input_y (Tensor) - The label tensor for the input data, of shape $$(N,)$$.

Outputs:
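• output (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)): the reconstructed input, the original input, and the mean and standard deviation of the latent distribution.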
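To make the four outputs concrete, here is a minimal pure-Python sketch of one plausible CVAE forward pass for a single example. It assumes the common design: the encoder sees both x and y and yields a mean and log-variance, z is drawn with the reparameterization trick, and the one-hot label is appended to z before decoding. `cvae_forward`, `one_hot`, `toy_encode`, and `toy_decode` are hypothetical names; `encode` stands in for the encoder plus whatever layers map its output to the latent parameters. This is not MindSpore's implementation.

```python
import math
import random


def one_hot(label, num_classes):
    # Encode an integer class label as a one-hot list of floats.
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def cvae_forward(x, y, encode, decode, num_classes, rng=random):
    """Conceptual single-example CVAE forward pass (not MindSpore's code).

    encode(x, y) -> (mu, log_var); decode(z_c) -> recon_x.
    Returns (recon_x, x, mu, std) in the documented output order.
    """
    mu, log_var = encode(x, y)
    # Convert the predicted log-variance into a standard deviation.
    std = [math.exp(0.5 * lv) for lv in log_var]
    # Reparameterization trick: z = mu + std * eps with eps ~ N(0, 1).
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]
    # Condition the decoder on the label by appending its one-hot code to z.
    z_c = z + one_hot(y, num_classes)
    recon_x = decode(z_c)
    return recon_x, x, mu, std


# Toy stand-ins: 4-element inputs, a 2-D latent space, 3 classes.
def toy_encode(x, y):
    mean = sum(x) / len(x)
    return [mean, float(y)], [0.0, 0.0]  # log_var = 0, so std = 1


def toy_decode(z_c):
    return [sum(z_c)] * 4


random.seed(0)
recon_x, x, mu, std = cvae_forward([1.0, 2.0, 3.0, 4.0], 2, toy_encode, toy_decode, num_classes=3)
print(mu, std)  # [2.5, 2.0] [1.0, 1.0]
```

Returning mu and std alongside recon_x is what lets a loss function compute both the reconstruction term and the KL term of the ELBO.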

Supported Platforms:

Ascend GPU

construct(x, y)

The inputs are x and y, so when using the CVAE interface, WithLossCell must be rewritten so that both are passed to the network.
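The reason for the rewrite: `nn.WithLossCell.construct(data, label)` passes only `data` to its backbone and uses `label` only in the loss, whereas this class needs both x and y. The pure-Python classes below are hypothetical stand-ins (`WithLoss`, `CVAEWithLoss`, `fake_cvae`, and `fake_loss` are illustrative names) that mirror this structure; in MindSpore the analogous fix would be a subclass of `nn.WithLossCell` whose `construct` calls the backbone with `(data, label)`.

```python
class WithLoss:
    """Mirrors nn.WithLossCell: only `data` reaches the backbone."""

    def __init__(self, backbone, loss_fn):
        self.backbone = backbone
        self.loss_fn = loss_fn

    def __call__(self, data, label):
        out = self.backbone(data)          # label is NOT forwarded
        return self.loss_fn(out, label)


class CVAEWithLoss(WithLoss):
    """Rewritten wrapper: the CVAE backbone needs both data and label."""

    def __call__(self, data, label):
        out = self.backbone(data, label)   # label is forwarded to the CVAE
        return self.loss_fn(out, label)


def fake_cvae(x, y):
    # Returns the 4-tuple (recon_x, x, mu, std), like ConditionalVAE.
    return x, x, [0.0], [1.0]


def fake_loss(out, label):
    recon_x, x, mu, std = out
    # Squared reconstruction error only; a real loss would add a KL term.
    return sum((a - b) ** 2 for a, b in zip(recon_x, x))


print(CVAEWithLoss(fake_cvae, fake_loss)([1.0, 2.0], 3))  # 0.0
# WithLoss(fake_cvae, fake_loss)([1.0, 2.0], 3) raises TypeError: y is missing.
```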

generate_sample(sample_y, generate_nums, shape)

Randomly sample from the latent space to generate samples.

Parameters
• sample_y (Tensor) – The labels of the samples to generate: a Tensor of shape (generate_nums,) and type mindspore.int32.

• generate_nums (int) – The number of samples to generate.

• shape (tuple) – The shape of the generated samples, which must be in the format (generate_nums, C, H, W) or (-1, C, H, W).

Returns

Tensor, the generated samples.
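Generating from a (C)VAE usually means drawing z from the standard normal prior, appending each label's one-hot code, decoding, and reshaping to `shape`. The pure-Python sketch below covers the argument checks implied by the parameter descriptions and the prior-sampling step; `validate_generate_args` and `sample_latent` are hypothetical names, and that ConditionalVAE samples from N(0, I) in exactly this way is an assumption.

```python
import random


def validate_generate_args(sample_y, generate_nums, shape):
    """Check the documented constraints on generate_sample's arguments.

    Hypothetical helper, not part of ConditionalVAE.
    """
    if not isinstance(generate_nums, int) or generate_nums <= 0:
        raise ValueError("generate_nums must be a positive int")
    # sample_y must hold exactly one label per generated sample.
    if len(sample_y) != generate_nums:
        raise ValueError("sample_y must have shape (generate_nums,)")
    # shape must be 4-D with a leading generate_nums or -1.
    if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] not in (-1, generate_nums):
        raise ValueError("shape must be (generate_nums, C, H, W) or (-1, C, H, W)")


def sample_latent(generate_nums, latent_size, rng=random):
    # Draw one standard-normal latent vector per requested sample.
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_size)] for _ in range(generate_nums)]


labels = [i % 10 for i in range(16)]  # 16 samples cycling through 10 classes
validate_generate_args(labels, 16, (-1, 1, 32, 32))
z = sample_latent(16, latent_size=20)
print(len(z), len(z[0]))  # 16 20
```

Allowing -1 as the leading dimension lets the batch size be inferred from generate_nums when reshaping.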

reconstruct_sample(x, y)

Reconstruct samples from original data.

Parameters
• x (Tensor) – The input tensor to reconstruct, of shape (N, C, H, W).

• y (Tensor) – The labels of the input tensor, of shape (N,).

Returns

Tensor, the reconstructed sample.
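reconstruct_sample returns only the reconstruction rather than a tuple. A minimal sketch of that behavior, assuming a forward function with the same (recon_x, x, mu, std) output order as construct, is shown below together with a mean-squared-error check, one simple way to judge reconstruction quality. `reconstruct_batch`, `mse`, and `toy_forward` are hypothetical names, not MindSpore APIs.

```python
def reconstruct_batch(xs, ys, forward):
    # N must agree between x (N, C, H, W) and y (N,).
    if len(xs) != len(ys):
        raise ValueError("x and y must have the same batch size N")
    # Keep only recon_x, the first element of (recon_x, x, mu, std).
    return [forward(x, y)[0] for x, y in zip(xs, ys)]


def mse(a, b):
    # Mean squared error between two flat lists of equal length.
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


def toy_forward(x, y):
    # Toy forward pass that returns a slightly shifted copy of its input.
    recon = [v + 0.1 for v in x]
    return recon, x, [0.0], [1.0]


xs = [[0.0, 1.0], [2.0, 3.0]]
recons = reconstruct_batch(xs, [0, 1], toy_forward)
print([round(mse(r, x), 4) for r, x in zip(recons, xs)])  # [0.01, 0.01]
```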