# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The ViT model
"""
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import dtype as mstype
from ..utils import to_2tuple
from .layer import Decoder, Encoder
[文档]class ViT(nn.Cell):
r"""
This module based on ViT backbone which including encoder, decoding_embedding, decoder and dense layer.
Args:
image_size (tuple[int]): The image size of input. Default: (192, 384).
in_channels (int): The input feature size of input. Default: 7.
out_channels (int): The output feature size of output. Default: 3.
patch_size (int): The patch size of image. Default: 16.
encoder_depths (int): The encoder depth of encoder layer. Default: 12.
encoder_embed_dim (int): The encoder embedding dimension of encoder layer. Default: 768.
encoder_num_heads (int): The encoder heads' number of encoder layer. Default: 12.
decoder_depths (int): The decoder depth of decoder layer. Default: 8.
decoder_embed_dim (int): The decoder embedding dimension of decoder layer. Default: 512.
decoder_num_heads (int): The decoder heads' number of decoder layer. Default: 16.
mlp_ratio (int): The rate of mlp layer. Default: 4.
dropout_rate (float): The rate of dropout layer. Default: 1.0.
compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer.
Default: mstype.float16.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`.
Outputs:
- **output** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`.
where patchify_size = (image_height * image_width) / (patch_size * patch_size)
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore import dtype as mstype
>>> from mindflow.cell import ViT
>>> input_tensor = Tensor(np.ones((32, 3, 192, 384)), mstype.float32)
>>> print(input_tensor.shape)
(32, 3, 192, 384)
>>> model = ViT(in_channels=3,
>>> out_channels=3,
>>> encoder_depths=6,
>>> encoder_embed_dim=768,
>>> encoder_num_heads=12,
>>> decoder_depths=6,
>>> decoder_embed_dim=512,
>>> decoder_num_heads=16,
>>> )
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
(32, 288, 768)
"""
def __init__(self,
image_size=(192, 384),
in_channels=7,
out_channels=3,
patch_size=16,
encoder_depths=12,
encoder_embed_dim=768,
encoder_num_heads=12,
decoder_depths=8,
decoder_embed_dim=512,
decoder_num_heads=16,
mlp_ratio=4,
dropout_rate=1.0,
compute_dtype=mstype.float16):
super(ViT, self).__init__()
image_size = to_2tuple(image_size)
grid_size = (image_size[0] // patch_size, image_size[1] // patch_size)
self.img_size = image_size
self.patch_size = patch_size
self.out_channels = out_channels
self.in_channels = in_channels
self.encoder_depths = encoder_depths
self.encoder_embed_dim = encoder_embed_dim
self.encoder_num_heads = encoder_num_heads
self.decoder_depths = decoder_depths
self.decoder_embed_dim = decoder_embed_dim
self.decoder_num_heads = decoder_num_heads
self.transpose = ops.Transpose()
self.encoder = Encoder(grid_size=grid_size,
in_channels=in_channels,
patch_size=patch_size,
depths=encoder_depths,
embed_dim=encoder_embed_dim,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
dropout_rate=dropout_rate,
compute_dtype=compute_dtype)
self.decoder_embedding = nn.Dense(encoder_embed_dim,
decoder_embed_dim,
has_bias=True,
weight_init="XavierUniform").to_float(compute_dtype)
self.decoder = Decoder(grid_size=grid_size,
depths=decoder_depths,
embed_dim=decoder_embed_dim,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
dropout_rate=dropout_rate,
compute_dtype=compute_dtype)
self.decoder_pred = nn.Dense(decoder_embed_dim,
patch_size ** 2 * out_channels,
has_bias=True,
weight_init="XavierUniform").to_float(compute_dtype)
def construct(self, x):
x = self.encoder(x)
x = self.decoder_embedding(x)
x = self.decoder(x)
images = self.decoder_pred(x)
images = P.Cast()(images, mstype.float32)
return images