mindflow.cell.unet2d 源代码

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
unet2d
"""
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import operations as P

from ..utils.check_func import check_param_type


class DoubleConv(nn.Cell):
    """Two stacked ``Conv2d(kernel_size=3) -> BatchNorm2d -> ReLU`` stages.

    Args:
        in_channels (int): Number of channels of the input tensor.
        out_channels (int): Number of channels produced by the second convolution.
        mid_channels (int, optional): Number of channels produced by the first
            convolution. Any falsy value (``None`` or ``0``) falls back to
            ``out_channels``. Default: None.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden_channels = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Cells are passed positionally (unpacked), in the same order as before.
        self.double_conv = nn.SequentialCell(*layers)

    def construct(self, x):
        """Apply both conv/bn/relu stages to ``x``."""
        features = self.double_conv(x)
        return features


class Down(nn.Cell):
    """Down-sampling stage: ``nn.MaxPool2d`` followed by a ``DoubleConv``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels of the double convolution.
        kernel_size: Pooling window, forwarded to ``nn.MaxPool2d``. Default: 2.
        stride: Pooling stride, forwarded to ``nn.MaxPool2d``. Default: 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        # Registration order (conv, then maxpool) is kept as-is.
        self.conv = DoubleConv(in_channels, out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def construct(self, x):
        """Pool ``x`` spatially, then run the double convolution on the result."""
        return self.conv(self.maxpool(x))


class Up(nn.Cell):
    """Up-sampling stage: transposed conv, pad to skip size, concat, ``DoubleConv``.

    The coarse input is upsampled by a ``nn.Conv2dTranspose`` that halves its
    channel count. It is zero-padded to the spatial size of the skip tensor and
    concatenated with it along axis 1. The result goes through a ``DoubleConv``.

    Args:
        in_channels (int): Channels of the coarse input ``x1``. The double
            convolution expects ``in_channels`` channels after concatenation.
            ``x2`` is therefore presumably expected to carry
            ``in_channels // 2`` channels (confirm against callers).
        out_channels (int): Number of output channels.
        kernel_size: Forwarded to ``nn.Conv2dTranspose``. Default: 2.
        stride: Forwarded to ``nn.Conv2dTranspose``. Default: 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride)
        self.conv = DoubleConv(in_channels, out_channels)
        self.cat = ops.Concat(axis=1)

    def construct(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s size, concat ``(x2, x1)`` and convolve.

        Dims 2 and 3 are read as height and width, i.e. (N, C, H, W) inputs.
        NOTE(review): if the upsampled ``x1`` is larger than ``x2`` the pad
        widths become negative; that case is not handled here.
        """
        upsampled = self.up(x1)

        _, _, up_h, up_w = ops.shape(upsampled)
        _, _, skip_h, skip_w = ops.shape(x2)

        # Split each spatial gap between the two sides; an odd remainder goes
        # to the bottom (height) / right (width) side.
        pad_h = skip_h - up_h
        pad_w = skip_w - up_w
        paddings = (
            (0, 0),
            (0, 0),
            (pad_h // 2, pad_h - pad_h // 2),
            (pad_w // 2, pad_w - pad_w // 2),
        )
        upsampled = ops.Pad(paddings)(upsampled)
        return self.conv(self.cat((x2, upsampled)))


class UNet2D(nn.Cell):
    r"""
    The 2-dimensional U-Net model.

    U-Net is a U-shaped convolutional neural network for biomedical image segmentation.
    It has a contracting path that captures context and an expansive path that enables
    precise localization. The details can be found in `U-Net: Convolutional Networks for
    Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        base_channels (int): The number of channels produced by the first double convolution.
            The four down-sampling stages produce ``2``, ``4``, ``8`` and ``16`` times this
            number of channels.
        data_format (str): The layout of the input and output tensors, either ``'NHWC'`` or
            ``'NCHW'``. Default: ``'NHWC'``.
        kernel_size (int): Kernel size of the max-pooling layers on the contracting path and of
            the transposed convolutions on the expansive path. Default: ``2``.
        stride (Union[int, tuple[int]]): Stride of the max-pooling layers and of the transposed
            convolutions. An int uses the same stride for height and width. A tuple of two ints
            gives the height and width strides respectively. Default: ``2``.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, H, W, in\_channels)` when
          `data_format` is ``'NHWC'``, or :math:`(batch\_size, in\_channels, H, W)` when it is
          ``'NCHW'``.

    Outputs:
        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, H, W, out\_channels)` when
          `data_format` is ``'NHWC'``, or :math:`(batch\_size, out\_channels, H, W)` when it is
          ``'NCHW'``.

    Raises:
        ValueError: If `data_format` is neither ``'NHWC'`` nor ``'NCHW'``.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>> from mindflow.cell import UNet2D
        >>> ms.set_context(mode=ms.GRAPH_MODE, save_graphs=False, device_target="GPU")
        >>> x = Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
        >>> unet = UNet2D(in_channels=3, out_channels=3, base_channels=3)
        >>> output = unet(x)
        >>> print(output.shape)
        (2, 128, 128, 3)
    """

    def __init__(self, in_channels, out_channels, base_channels, data_format="NHWC", kernel_size=2, stride=2):
        super().__init__()
        check_param_type(in_channels, "in_channels", data_type=int, exclude_type=bool)
        check_param_type(out_channels, "out_channels", data_type=int, exclude_type=bool)
        check_param_type(base_channels, "base_channels", data_type=int, exclude_type=bool)
        check_param_type(data_format, "data_format", data_type=str, exclude_type=bool)
        if data_format not in ("NHWC", "NCHW"):
            raise ValueError(
                "data_format must be 'NHWC' or 'NCHW', but got data_format: {}".format(data_format))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.base_channels = base_channels
        self.data_format = data_format

        # Contracting path: channel count doubles at every Down stage.
        self.inc = DoubleConv(self.in_channels, self.base_channels, mid_channels=None)
        self.down1 = Down(self.base_channels, self.base_channels * 2, self.kernel_size, self.stride)
        self.down2 = Down(self.base_channels * 2, self.base_channels * 4, self.kernel_size, self.stride)
        self.down3 = Down(self.base_channels * 4, self.base_channels * 8, self.kernel_size, self.stride)
        self.down4 = Down(self.base_channels * 8, self.base_channels * 16, self.kernel_size, self.stride)

        # Expansive path: each Up stage halves the channel count and consumes a skip tensor.
        self.up1 = Up(self.base_channels * 16, self.base_channels * 8, self.kernel_size, self.stride)
        self.up2 = Up(self.base_channels * 8, self.base_channels * 4, self.kernel_size, self.stride)
        self.up3 = Up(self.base_channels * 4, self.base_channels * 2, self.kernel_size, self.stride)
        self.up4 = Up(self.base_channels * 2, self.base_channels, self.kernel_size, self.stride)

        # The head sees the last Up output concatenated with the raw input, hence the channel sum.
        self.outc = nn.Conv2d(self.base_channels + self.in_channels, self.out_channels, kernel_size=3, stride=1)
        self.transpose = P.Transpose()
        self.cat = P.Concat(axis=1)

    def construct(self, x):
        """forward"""
        # All internal cells operate channel-first; convert NHWC input to NCHW.
        if self.data_format == "NHWC":
            x0 = self.transpose(x, (0, 3, 1, 2))
        else:
            x0 = x

        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # Long skip connection from the (channel-first) input to the output head.
        x = self.cat((x, x0))
        x = self.outc(x)

        # Restore the caller's layout.
        if self.data_format == "NHWC":
            out = self.transpose(x, (0, 2, 3, 1))
        else:
            out = x
        return out