mindscience.models.layers.UNet2D
- class mindscience.models.layers.UNet2D(in_channels, out_channels, base_channels, n_layers=4, data_format='NHWC', kernel_size=2, stride=2, activation='relu', enable_bn=True)
The 2-dimensional U-Net model. U-Net is a U-shaped convolutional neural network originally designed for biomedical image segmentation. It has a contracting path that captures context and an expansive path that enables precise localization. Details can be found in the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".
- Parameters
in_channels (int) – The number of input channels.
out_channels (int) – The number of output channels.
base_channels (int) – The number of base channels of UNet2D.
n_layers (int, optional) – The number of downsample and upsample convolutions. Default: 4.
data_format (str, optional) – The format of input data. Default: "NHWC".
kernel_size (int, optional) – Specifies the height and width of the 2D convolution kernel. Default: 2.
stride (Union[int, tuple[int]], optional) – The step by which the kernel moves: either an int applied to both height and width, or a tuple of two ints giving the height and width steps respectively. Default: 2.
activation (Union[str, class], optional) – The activation function, given either as a str or as a class. Default: "relu".
enable_bn (bool, optional) – Specifies whether to use batch norm in convolutions. Default: True.
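To get a feel for how n_layers, stride and base_channels interact, the sketch below lists the spatial size and channel count at each level of the contracting path. It assumes that every downsampling step divides the resolution by stride and doubles the channel count, which is the pattern of the original U-Net paper; this is not taken from the UNet2D source, and unet_level_shapes is a hypothetical helper, not part of mindscience.

```python
def unet_level_shapes(resolution, base_channels, n_layers=4, stride=2):
    """Hypothetical helper: (resolution, channels) at each encoder level.

    Assumption (not from the UNet2D source): each downsampling step
    divides the spatial size by `stride` and doubles the channel count.
    """
    if n_layers <= 0:
        raise ValueError("n_layers must be a positive integer")
    levels = [(resolution, base_channels)]
    channels = base_channels
    for _ in range(n_layers):
        if resolution % stride != 0:
            # Under the assumption above, an odd-sized feature map could
            # not be restored to its original size on the way back up.
            raise ValueError(
                f"resolution {resolution} is not divisible by stride {stride}"
            )
        resolution //= stride
        channels *= 2
        levels.append((resolution, channels))
    return levels


print(unet_level_shapes(128, base_channels=3))
# [(128, 3), (64, 6), (32, 12), (16, 24), (8, 48)]
```

Under these assumptions, the input resolution should be divisible by stride ** n_layers (16 with the default arguments) for the output to match the input size exactly.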
- Inputs:
x (Tensor) - Tensor of shape \((batch\_size, resolution, resolution, channels)\).
- Outputs:
output (Tensor) - Tensor of shape \((batch\_size, resolution, resolution, channels)\).
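The shapes above are written in the default "NHWC" layout. As a small sketch, assuming the layout strings carry their usual meaning (N = batch, H = height, W = width, C = channels), the hypothetical helper expected_shape below shows the tuple to expect for each supported data_format; it is illustrative only and not part of mindscience.

```python
def expected_shape(batch_size, resolution, channels, data_format="NHWC"):
    """Hypothetical helper: tensor shape for a square image batch."""
    if data_format == "NHWC":
        # Channels-last: the layout used in the shapes documented above.
        return (batch_size, resolution, resolution, channels)
    if data_format == "NCHW":
        # Channels-first: the channel axis moves to position 1.
        return (batch_size, channels, resolution, resolution)
    raise ValueError(f"unsupported data_format: {data_format!r}")


print(expected_shape(2, 128, 3))          # (2, 128, 128, 3)
print(expected_shape(2, 128, 3, "NCHW"))  # (2, 3, 128, 128)
```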
- Raises
ValueError – If data_format is not 'NHWC' or 'NCHW'.
ValueError – If n_layers is 0.
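When the constructor arguments come from a configuration file or command line, it can be convenient to check them before building the network. The following is a minimal sketch that mirrors the two documented error conditions; check_unet2d_args is a hypothetical helper written for illustration, and the real constructor may perform further checks that are not listed here.

```python
VALID_DATA_FORMATS = ("NHWC", "NCHW")


def check_unet2d_args(data_format="NHWC", n_layers=4):
    """Hypothetical pre-check mirroring the documented ValueError cases."""
    if data_format not in VALID_DATA_FORMATS:
        raise ValueError(
            f"data_format must be one of {VALID_DATA_FORMATS}, got {data_format!r}"
        )
    if n_layers == 0:
        raise ValueError("n_layers must not be 0")


# Usage: valid arguments pass silently, invalid ones raise ValueError.
check_unet2d_args("NCHW", 4)
try:
    check_unet2d_args("NHWC", 0)
except ValueError as err:
    print(f"rejected: {err}")
```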
Examples
>>> import mindspore as ms
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.common.dtype as mstype
>>> from mindscience.models.layers import UNet2D
>>> ms.set_context(mode=ms.GRAPH_MODE, save_graphs=False)
>>> x = Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
>>> unet = UNet2D(in_channels=3, out_channels=3, base_channels=3)
>>> output = unet(x)
>>> print(output.shape)
(2, 128, 128, 3)