[{"data":1,"prerenderedAt":585},["ShallowReactive",2],{"content-query-NsFeDz3Y3x":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":579,"_id":580,"_source":581,"_file":582,"_stem":583,"_extension":584},"/technology-blogs/zh/2173","zh",false,"","应用案例 | 不用学PS，属于程序员的修图方式来了！","基于此，论文（Contextual Residual Aggregation for Ultra High-Resolution Image Inpainting）提出了一种上下文残差聚合机制（CRA），将上下文聚合残差添加到上采样的神经网络修复结果来输出最终结果。","2023-02-16","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2023/02/22/8c2188593e504fce836218da3a8e9385.png","technology-blogs","大V博文",{"type":15,"children":16,"toc":561},"root",[17,25,31,36,45,50,60,69,74,79,84,89,94,103,111,116,124,129,137,142,150,157,165,170,180,185,190,197,205,213,218,226,236,241,248,253,258,265,270,278,285,293,302,308,314,321,329,334,339,347,355,360,368,376,381,386,393,398,406,411,419,427,432,440,448,456,461,469,474,479,487,495,500,508,513,521,528,536,541,554],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"应用案例-不用学ps属于程序员的修图方式来了",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":28},"p",{},[29],{"type":24,"value":30},"传统的图像修复方法只能处理低分辨率的输入图像，而对低分辨率修复结果进行简单的上采样只会产生大而模糊的结果。我们知道，在模糊图像上添加高频残差部分可以丰富图像的细节和纹理，基于此，论文（Contextual Residual Aggregation for Ultra High-Resolution Image 
Inpainting）提出了一种上下文残差聚合机制（CRA），将上下文聚合残差添加到上采样的神经网络修复结果来输出最终结果。",{"type":18,"tag":26,"props":32,"children":33},{},[34],{"type":24,"value":35},"通过注意力转移模块（ATM）从上下文残差和注意力分数来计算掩模区域中的聚合残差，通过搭建生成对抗网络进行低分辨率的图像预测，很好地抑制了内存和计算时间的成本。此外，论文引入了一些其他技术来提高修复质量、计算速度，如：注意力分数共享、多尺度注意力转移机制、轻量级门控卷积（LWGC）。最终，该模型可以高精度修复占有25%孔洞大小的大型图像（高达8K）。",{"type":18,"tag":26,"props":37,"children":38},{},[39],{"type":18,"tag":40,"props":41,"children":42},"strong",{},[43],{"type":24,"value":44},"配置环境",{"type":18,"tag":26,"props":46,"children":47},{},[48],{"type":24,"value":49},"本教程我们在GPU环境下，使用图模式运行实验。",{"type":18,"tag":51,"props":52,"children":54},"pre",{"code":53},"from mindspore import context\n\n#选择执行模式为图模式；指定训练使用的平台为\"GPU\"，如需使用昇腾硬件可将其替换为\"Ascend\"\ncontext.set_context(mode=context.GRAPH_MODE, device_target='GPU')\n",[55],{"type":18,"tag":56,"props":57,"children":58},"code",{"__ignoreMap":7},[59],{"type":24,"value":53},{"type":18,"tag":61,"props":62,"children":64},"h3",{"id":63},"准备数据",[65],{"type":18,"tag":40,"props":66,"children":67},{},[68],{"type":24,"value":63},{"type":18,"tag":26,"props":70,"children":71},{},[72],{"type":24,"value":73},"本案例使用places2数据集作为训练集，在该官网中下载High-resolution 
images训练数据集，该数据集共有443个场景类别，包含超过180万张1024x1024的图片。",{"type":18,"tag":26,"props":75,"children":76},{},[77],{"type":24,"value":78},"mask数据集共包含100张掩膜图片，可使用两种方法动态生成不规则mask，或模拟撕裂、划痕、斑点等，或通过随机操作真实的对象形状模板来生成掩膜。",{"type":18,"tag":26,"props":80,"children":81},{},[82],{"type":24,"value":83},"推理数据包含两组匹配的image图像和mask图像。",{"type":18,"tag":26,"props":85,"children":86},{},[87],{"type":24,"value":88},"同时，训练数据，包含16张图像，放到/examples目录，用于案例CRA.ipynb测试。",{"type":18,"tag":26,"props":90,"children":91},{},[92],{"type":24,"value":93},"如需下载上述数据集进行模型推理，欢迎大家私信后台留言“数据集”获得。将解压后数据集放到CRA目录下，文件目录如下所示：",{"type":18,"tag":26,"props":95,"children":96},{},[97],{"type":18,"tag":98,"props":99,"children":102},"img",{"alt":100,"src":101},"image.png","https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222062452.26904676057094904116207762564378:50540221072327:2400:E4445C9E0422CCB965753A6115ACED5436D212E4FCB493B0ECC9A7A44F3275B2.png",[],{"type":18,"tag":61,"props":104,"children":106},{"id":105},"数据处理",[107],{"type":18,"tag":40,"props":108,"children":109},{},[110],{"type":24,"value":105},{"type":18,"tag":26,"props":112,"children":113},{},[114],{"type":24,"value":115},"对于places2数据集：定义InpaintDataset()类读取数据，并将图像随机裁剪到512x512大小，进行归一化处理。",{"type":18,"tag":51,"props":117,"children":119},{"code":118},"import os\nimport cv2\n\n\nclass InpaintDataset():\n    \"\"\"Process image dataset\"\"\"\n\n    def __init__(self, args):\n        self.args = args\n        self.imglist = self.get_files('./examples')\n\n    def get_files(self, path):\n        ret = []\n        for tuple_path in os.walk(path):\n            for filespath in tuple_path[2]:\n                ret.append(os.path.join(tuple_path[0], filespath))\n        return ret\n\n    def __len__(self):\n        return len(self.imglist)\n\n    def __getitem__(self, index):\n        img = cv2.imread(self.imglist[index])\n        h, w = self.args.IMG_SHAPE[0], self.args.IMG_SHAPE[1]\n        img = 
cv2.resize(img, (h, w))\n        img = img / 127.5 - 1\n        img = img.transpose((2, 0, 1))\n        return img\n",[120],{"type":18,"tag":56,"props":121,"children":122},{"__ignoreMap":7},[123],{"type":24,"value":118},{"type":18,"tag":26,"props":125,"children":126},{},[127],{"type":24,"value":128},"对于mask数据集：从数据集中随机选取mask图像，并进行随机水平翻转、旋转随机角度、随机缩放0.8~1.0倍一系列数据增强操作，输出[1, 1, 512, 512]大小的mask张量。",{"type":18,"tag":51,"props":130,"children":132},{"code":131},"import random\n\nimport mindspore\nimport mindspore.ops as ops\nimport mindspore.dataset as ds\nfrom mindspore import Tensor\n\nfrom src.process_dataset.mask import get_files, read_masks, random_rotate_image, random_resize_image\n\n\ndef random_mask(args):\n    \"\"\"Process mask dataset\"\"\"\n\n    img_shape = args.IMG_SHAPE\n    height = img_shape[0]\n    width = img_shape[1]\n    path_list, n_masks = get_files('./mask_templates')\n    nd = random.randint(0, n_masks - 1)\n    path_mask = path_list[nd]\n    mask = read_masks(path_mask)\n    mask = ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5)(mask)\n    scale = random.uniform(0.8, 1.0)\n    mask = random_rotate_image(mask)\n    mask = random_resize_image(mask, scale, height, width)\n    crop = ds.vision.c_transforms.CenterCrop((height, width))\n    mask1 = crop(mask)\n    mask_show = mask1\n    mask2 = Tensor.from_numpy(mask1)\n    mask3 = mask2.astype(mindspore.float32)\n    mask4 = mask3[:, :, 0:1]\n    mask5 = ops.ExpandDims()(mask4, 0)\n    mask6 = ops.Mul()(1 / 255, mask5)\n    mask = ops.Reshape()(mask6, (1, height, width, 1))\n    mask = ops.Transpose()(mask, (0, 3, 1, 2))\n    return mask, 
mask_show\n",[133],{"type":18,"tag":56,"props":134,"children":135},{"__ignoreMap":7},[136],{"type":24,"value":131},{"type":18,"tag":26,"props":138,"children":139},{},[140],{"type":24,"value":141},"调用InpaintDataset和GeneratorDataset读取数据集，通过create_dict_iterator创建数据集迭代对象，将输入图像、mask掩膜图像以及待恢复图像进行可视化处理，部分训练数据展示如下：",{"type":18,"tag":51,"props":143,"children":145},{"code":144},"import numpy as np\nimport matplotlib.pyplot as plt\n\nfrom src.config.config import cra_config as config\n\n\ndataset_generator = InpaintDataset(config)\ndataset = ds.GeneratorDataset(dataset_generator, ['image'])\ndataset_size = len(dataset_generator)\ntotal_batch = dataset_size // config.train_batchsize\ndataset = dataset.batch(config.train_batchsize, drop_remainder=True)\ndataset = dataset.create_dict_iterator(output_numpy=True)\ndataset = next(dataset)\nfor i, image in enumerate(dataset['image']):\n    image = image[(2, 1, 0), :, :]\n    image = image.transpose(1, 2, 0)\n    mask, mask_show = random_mask(config)\n    mask = ops.Squeeze(0)(mask).asnumpy()\n    mask = mask.transpose(1, 2, 0)\n    real = image * (1-mask)\n    result = np.concatenate([image, mask_show, real], 1)\n    plt.subplot(8, 1, i+1)\n    plt.axis('off')\n    
plt.imshow(result)\nplt.show()\n",[146],{"type":18,"tag":56,"props":147,"children":148},{"__ignoreMap":7},[149],{"type":24,"value":144},{"type":18,"tag":26,"props":151,"children":152},{},[153],{"type":18,"tag":98,"props":154,"children":156},{"alt":7,"src":155},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2023/02/22/d4e165110de141f28a59e731c6941eb4.png",[],{"type":18,"tag":26,"props":158,"children":159},{},[160],{"type":18,"tag":40,"props":161,"children":162},{},[163],{"type":24,"value":164},"模型架构",{"type":18,"tag":26,"props":166,"children":167},{},[168],{"type":24,"value":169},"在数据加载完成后，我们进行网络模型的整体搭建。具体来说，我们使用生成对抗网络来预测低分辨率图像修复结果，并对其进行上采样以产生跟待修复图像同样尺寸的模糊图像；通过聚合上下文patches的加权高频残差来生成缺失内容的高频信息；将聚合残差添加到大而模糊的图像中获得清晰修复图像。接下来，将从部分到整体介绍网络架构。",{"type":18,"tag":171,"props":172,"children":174},"h2",{"id":173},"轻量级门控卷积lwgc",[175],{"type":18,"tag":40,"props":176,"children":177},{},[178],{"type":24,"value":179},"轻量级门控卷积（LWGC）",{"type":18,"tag":26,"props":181,"children":182},{},[183],{"type":24,"value":184},"综合分析普通卷积和部分卷积对处理不规则空洞区域的缺陷，论文初步采用门控卷积（GC）来搭建模型各卷积层，然而，与普通卷积相比，GC的参数数量和处理时间几乎翻了一倍。因此，该论文提出了三个修改版本的轻量级门控卷积：depth-separable LWGC(LWGCds)、pixelwise LWGC(LWGCpw)、single-channel 
LWGC(LWGCsc)。",{"type":18,"tag":26,"props":186,"children":187},{},[188],{"type":24,"value":189},"原始GC的输出可以表示为：",{"type":18,"tag":26,"props":191,"children":192},{},[193],{"type":18,"tag":98,"props":194,"children":196},{"alt":100,"src":195},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222062717.49999952922329109863451680372944:50540221072327:2400:C13E57145FC34730911BC602A92D5570E7138BEF2B611D94EA4064451C800B76.png",[],{"type":18,"tag":26,"props":198,"children":199},{},[200],{"type":18,"tag":98,"props":201,"children":204},{"alt":202,"src":203},"cke_15011.png","https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222062757.78359817447392698933442593803387:50540221072327:2400:5A6F57EADEADB1178D19C2BF2AB4F3C9A6F5B0B17CFFC3445014173EB76BC3D4.png",[],{"type":18,"tag":51,"props":206,"children":208},{"code":207},"import mindspore.nn as nn\nfrom mindspore.common.initializer import TruncatedNormal\n\n\nclass ScConv(nn.Cell):\n    \"\"\"Build LWGCsc Gate branch\"\"\"\n\n    def __init__(self, in_channel, kernel_size, stride, padding, dilation):\n        super(ScConv, self).__init__()\n        self.single_channel_conv = nn.Conv2d(in_channels=in_channel, out_channels=1, kernel_size=kernel_size,\n                                             stride=stride, pad_mode='same', padding=padding, dilation=dilation,\n                                             group=1, has_bias=True, weight_init=TruncatedNormal(0.05))\n\n    def construct(self, x):\n        x = self.single_channel_conv(x)\n        return x\n",[209],{"type":18,"tag":56,"props":210,"children":211},{"__ignoreMap":7},[212],{"type":24,"value":207},{"type":18,"tag":26,"props":214,"children":215},{},[216],{"type":24,"value":217},"结合nn.Conv2d普通卷积搭建门控卷积网络层：",{"type":18,"tag":51,"props":219,"children":221},{"code":220},"class GatedConv2d(nn.Cell):\n    \"\"\"Build LWGCsc and LWGCds 
network layer\"\"\"\n\n    def __init__(self, in_channel, out_channel, kernel_size, stride, dilation, sc=False):\n        super(GatedConv2d, self).__init__()\n        self.activation = nn.ELU(alpha=1.0)\n        if sc:\n            self.conv2d = nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode='same', padding=0,\n                                    dilation=dilation, has_bias=True, weight_init=TruncatedNormal(0.05))\n            self.gate_factor = ScConv(in_channel, kernel_size, stride, 0, dilation)\n        else:\n            self.conv2d = nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode='same', padding=0,\n                                    dilation=dilation, has_bias=True, weight_init=TruncatedNormal(0.05))\n            self.gate_factor = DepthSeparableConv(in_channel, out_channel, stride, dilation)\n        self.sigmoid = nn.Sigmoid()\n\n    def construct(self, x):\n        gc_f = self.conv2d(x)\n        gc_g = self.gate_factor(x)\n        x = self.sigmoid(gc_g) * self.activation(gc_f)\n        return x\n",[222],{"type":18,"tag":56,"props":223,"children":224},{"__ignoreMap":7},[225],{"type":24,"value":220},{"type":18,"tag":227,"props":228,"children":230},"h4",{"id":229},"注意力计算模块acm",[231],{"type":18,"tag":40,"props":232,"children":233},{},[234],{"type":24,"value":235},"注意力计算模块（ACM）",{"type":18,"tag":26,"props":237,"children":238},{},[239],{"type":24,"value":240},"注意力分数是根据高级特征图（表示为P）的区域亲和性（region 
affinity）计算的，P被划分成特定大小的块，通过计算缺失区域内外块之间的余弦相似度来获取相似度分数，具体表示如下：",{"type":18,"tag":26,"props":242,"children":243},{},[244],{"type":18,"tag":98,"props":245,"children":247},{"alt":100,"src":246},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222062912.40162197427763543460368398049298:50540221072327:2400:9F6292AA02824B6AAD814ADF3FEA7038C0501A85B378FD522ACE51154C6A0697.png",[],{"type":18,"tag":26,"props":249,"children":250},{},[251],{"type":24,"value":252},"其中，pi是从P中hole外提取的第i个patch，pj是从P中hole内提取到的第j个patch。",{"type":18,"tag":26,"props":254,"children":255},{},[256],{"type":24,"value":257},"将softmax应用于相似度分数以获取P中每个patch的注意力分数：",{"type":18,"tag":26,"props":259,"children":260},{},[261],{"type":18,"tag":98,"props":262,"children":264},{"alt":100,"src":263},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222062925.18337660519841765670658816342930:50540221072327:2400:3601285CA36F37A59966287BD8196B6CDEF1EF8FF4526638ACE08D8C00AF00C0.png",[],{"type":18,"tag":26,"props":266,"children":267},{},[268],{"type":24,"value":269},"其中，N是P中hole区域外的patches个数。在我们的框架中，采用64×64的高级特征图计算注意力分数，并划分每个patch的大小为3×3，在张量correspondence中保存注意力分数。",{"type":18,"tag":51,"props":271,"children":273},{"code":272},"from src.models.compute_attention import downsample, InitConv2d\n\n\nclass ContextualAttention(nn.Cell):\n    \"\"\"\n    Attention score computing module.\n\n    Args:\n        softmax_scale(int): scaled softmax for attention.\n        src(Tensor): input feature to match (foreground).\n        ref(Tensor): input feature for match (background).\n        mask(Tensor): input mask for ref, indicating patches not available.\n\n    Return:\n        out: Foreground area filled with context information\n             (It generally refers to the 64 * 64 feature map used to calculate attention scores).\n        correspondence: Attention score.\n    \"\"\"\n\n    def 
__init__(self, softmax_scale=10, fuse=True, dtype=mindspore.float32):\n        super(ContextualAttention, self).__init__()\n        self.softmax_scale = softmax_scale\n        self.fuse = fuse\n        self.dtype = dtype\n        self.reducesum = ops.ReduceSum(False)\n        self.unfold1 = nn.Unfold([1, 3, 3, 1], [1, 2, 2, 1], [1, 1, 1, 1], 'same')\n        self.unfold2 = nn.Unfold([1, 3, 3, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'same')\n        self.transpose = ops.Transpose()\n        self.reshape = ops.Reshape()\n        self.pool1 = nn.MaxPool2d(16, 16, 'same', 'NCHW')\n        self.pool2 = nn.MaxPool2d(3, 1, 'same', 'NCHW')\n        self.maximum = ops.Maximum()\n        self.sqrt = ops.Sqrt()\n        self.square = ops.Square()\n        self.eye = ops.Eye()\n        self.reducemax = ops.ReduceMax(True)\n        self.greaterequal = ops.GreaterEqual()\n        self.pow = ops.Pow()\n        self.div = ops.Div()\n        self.softmax = nn.Softmax(1)\n        self.cat = ops.Concat(0)\n        self.conv1 = InitConv2d([3, 3, 128, 1024], 1, True)\n        self.conv2 = InitConv2d([3, 3, 1, 1], 1, True)\n        self.disconv1 = InitConv2d([3, 3, 128, 1024], 2, False)\n\n    def construct(self, src, ref, mask, method='SOFT'):\n        \"\"\"compute attention score\"\"\"\n\n        # get shapes\n        shape_src = src.shape\n        batch_size = shape_src[0]\n        nc = shape_src[1]\n        # raw features\n        raw_feats = self.unfold1(ref)\n        raw_feats = self.transpose(raw_feats, (0, 2, 3, 1))\n        raw_feats = self.reshape(raw_feats, (batch_size, -1, 3, 3, nc))\n        raw_feats = self.transpose(raw_feats, (0, 2, 3, 4, 1))\n        split = ops.Split(0, batch_size)\n        raw_feats_lst = split(raw_feats)\n        # resize\n        src = downsample(src)\n        ref = downsample(ref)\n        ss = src.shape\n        rs = ref.shape\n        src_lst = split(src)\n        feats = self.unfold2(ref)\n        feats = self.transpose(feats, (0, 2, 3, 1))\n        
feats = self.reshape(feats, (batch_size, -1, 3, 3, nc))\n        feats = self.transpose(feats, (0, 2, 3, 4, 1))\n        feats_lst = split(feats)\n        # process mask\n        mask = self.pool1(mask)\n        mask = self.pool2(mask)\n        mask = 1 - mask\n        mask = self.reshape(mask, (1, -1, 1, 1))\n\n        y_lst, y_up_lst = [], []\n        offsets = []\n        fuse_weight = self.reshape(self.eye(3, 3, mindspore.float32), (3, 3, 1, 1))\n        for x, r, raw_r in zip(src_lst, feats_lst, raw_feats_lst):\n            r = r[0]\n            r = r / self.maximum(self.sqrt(self.reducesum(self.square(r), [0, 1, 2])), 1e-8)\n            r_kernel = self.transpose(r, (3, 2, 0, 1))\n            y = self.conv1(x, r_kernel)\n            if self.fuse:\n                # conv implementation for fuse scores to encourage large patches\n                yi = self.reshape(y, (1, 1, ss[2] * ss[3], rs[2] * rs[3]))\n                fuse_weight_kernel = ops.Transpose()(fuse_weight, (3, 2, 0, 1))\n                yi = self.conv2(yi, fuse_weight_kernel)\n                yi = self.transpose(yi, (0, 2, 3, 1))\n                yi = self.reshape(yi, (1, ss[2], ss[3], rs[2], rs[3]))\n                yi = self.transpose(yi, (0, 2, 1, 4, 3))\n                yi = self.reshape(yi, (1, ss[2] * ss[3], rs[2] * rs[3], 1))\n                yi = self.transpose(yi, (0, 3, 1, 2))\n                yi = self.conv2(yi, fuse_weight_kernel)\n                yi = self.transpose(yi, (0, 2, 3, 1))\n                yi = self.reshape(yi, (1, ss[3], ss[2], rs[3], rs[2]))\n                yi = self.transpose(yi, (0, 2, 1, 4, 3))\n                y = yi\n            y = self.reshape(y, (1, ss[2], ss[3], rs[2] * rs[3]))\n            y = self.transpose(y, (0, 3, 1, 2))\n            if method == 'HARD':\n                ym = self.reducemax(y, 1)\n                y = y * mask\n                coef = self.greaterequal(y, max(y, 1)).astype(self.dtype)\n                y = self.pow(coef * self.div(y, ym + 
1e-04), 2)\n            elif method == 'SOFT':\n                y = (self.softmax(y * mask * self.softmax_scale)) * mask\n            y = self.reshape(y, (1, rs[2] * rs[3], ss[2], ss[3]))\n            if self.dtype == mindspore.float32:\n                offset = y.argmax(1)\n                offsets.append(offset)\n            feats = raw_r[0]\n            feats_kernel = self.transpose(feats, (3, 2, 0, 1))\n            y_up = self.disconv1(y, feats_kernel)\n            y_lst.append(y)\n            y_up_lst.append(y_up)\n        out, correspondence = self.cat(y_up_lst), self.cat(y_lst)\n        out = self.reshape(out, (shape_src[0], shape_src[1], shape_src[2], shape_src[3]))\n        return out, correspondence\n",[274],{"type":18,"tag":56,"props":275,"children":276},{"__ignoreMap":7},[277],{"type":24,"value":272},{"type":18,"tag":26,"props":279,"children":280},{},[281],{"type":18,"tag":98,"props":282,"children":284},{"alt":7,"src":283},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2023/02/22/f85d640f5a5b4aa3874491bc7fb44ffb.png",[],{"type":18,"tag":51,"props":286,"children":288},{"code":287},"class ApplyAttention(nn.Cell):\n    \"\"\"\n    Attention transfer module(used for training)\n    (It generally used for 128 * 128 / 256 * 256 feature map).\n\n    Args:\n        shp(list): the shape of input feature map.\n        shp_att(list): the shape of attention score.\n\n    Return:\n        out: Feature map filled by attention transfer module.\n    \"\"\"\n\n    def __init__(self, shp, shp_att):\n        super(ApplyAttention, self).__init__()\n        self.shp = shp\n        self.shp_att = shp_att\n        self.rate = self.shp[2] // self.shp_att[2]\n        self.kernel = self.rate * 2\n        self.batch_size = self.shp[0]\n        self.sz = self.shp[2]\n        self.nc = self.shp[1]\n        self.unfold = nn.Unfold([1, self.kernel, self.kernel, 1], [1, self.rate, self.rate, 1], [1, 1, 1, 1], 'same')\n        self.transpose = ops.Transpose()\n        
self.reshape = ops.Reshape()\n        self.split = ops.Split(0, self.batch_size)\n        self.disconv1 = InitConv2d([8, 8, 64, 1024], self.rate, False)\n        self.disconv2 = InitConv2d([16, 16, 32, 1024], self.rate, False)\n        self.concat = ops.Concat(0)\n        self.conv_pl2 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 1),\n            GatedConv2d(64, 64, 3, 1, 2)\n        )\n        self.conv_pl1 = nn.SequentialCell(\n            GatedConv2d(32, 32, 3, 1, 1),\n            GatedConv2d(32, 32, 3, 1, 2)\n        )\n\n    def construct(self, x, correspondence):\n        \"\"\"apply attention on training\"\"\"\n\n        raw_feats = self.unfold(x)\n        raw_feats = self.transpose(raw_feats, (0, 2, 3, 1))\n        raw_feats = self.reshape(raw_feats, (self.batch_size, -1, self.kernel, self.kernel, self.nc))\n        raw_feats = self.transpose(raw_feats, (0, 2, 3, 4, 1))\n        raw_feats_lst = self.split(raw_feats)\n        ys = []\n        correspondence = self.transpose(correspondence, (0, 2, 3, 1))\n        att_lst = self.split(correspondence)\n        for feats, att in zip(raw_feats_lst, att_lst):\n            feats_kernel = self.transpose(feats[0], (3, 2, 0, 1))\n            att = self.transpose(att, (0, 3, 1, 2))\n            if self.shp[2] == 128:\n                y1 = self.disconv1(att, feats_kernel)\n                ys.append(y1)\n            elif self.shp[2] == 256:\n                y2 = self.disconv2(att, feats_kernel)\n                ys.append(y2)\n            else:\n                print('Value Error')\n        out = self.concat(ys)\n        if self.shp[2] == 128:\n            out = self.conv_pl2(out)\n        elif self.shp[2] == 256:\n            out = self.conv_pl1(out)\n        else:\n            print('conv error')\n        return 
out\n",[289],{"type":18,"tag":56,"props":290,"children":291},{"__ignoreMap":7},[292],{"type":24,"value":287},{"type":18,"tag":61,"props":294,"children":296},{"id":295},"the-overall-pipeline-of-cra",[297],{"type":18,"tag":40,"props":298,"children":299},{},[300],{"type":24,"value":301},"The Overall Pipeline of CRA",{"type":18,"tag":61,"props":303,"children":305},{"id":304},"给定一个高分辨率输入图像首先将图像下采样到512512低分辨率图像然后对其进行上采样以获得与原始输入相同大小的模糊图像低频分量生成器获取低分辨率图像并进行图像修复同时注意力分数由生成器的注意力计算模块acm计算",[306],{"type":24,"value":307},"给定一个高分辨率输入图像，首先将图像下采样到512×512（低分辨率图像）,然后对其进行上采样以获得与原始输入相同大小的模糊图像（低频分量）。生成器获取低分辨率图像并进行图像修复，同时，注意力分数由生成器的注意力计算模块（ACM）计算。",{"type":18,"tag":61,"props":309,"children":311},{"id":310},"通过从原始输入中减去模糊低频分量计算图像上下文残差contextual-residual然后通过注意力转移模块atm结合注意力分数从上下文残差图像中计算空洞区域的聚合残差aggregated-residual最后将聚合残差添加到上采样的图像修复结果中得到hole区域的最终修复结果而hole外的区域依旧采用原始输入cra机制整体流程如下图",[312],{"type":24,"value":313},"通过从原始输入中减去模糊低频分量计算图像上下文残差(contextual residual)，然后通过注意力转移模块（ATM）结合注意力分数从上下文残差图像中计算空洞区域的聚合残差(aggregated residual)，最后，将聚合残差添加到上采样的图像修复结果中得到hole区域的最终修复结果，而hole外的区域依旧采用原始输入。CRA机制整体流程如下图。",{"type":18,"tag":26,"props":315,"children":316},{},[317],{"type":18,"tag":98,"props":318,"children":320},{"alt":100,"src":319},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222063141.96784108453363681980321828176293:50540221072327:2400:84720D2DE48DC89015582C563042C41CE2EA0448C5AFF25451C7F497F0780277.png",[],{"type":18,"tag":26,"props":322,"children":323},{},[324],{"type":18,"tag":40,"props":325,"children":326},{},[327],{"type":24,"value":328},"生成器",{"type":18,"tag":26,"props":330,"children":331},{},[332],{"type":24,"value":333},"生成器部分采用两阶段的从粗到细的网络架构，其中粗网生成图像修复的粗糙效果，而细网在粗网的基础上预测更精细的结果。生成器将原始图像和mask图像作为输入并生成一张完整的修复图像，输入和输出大小为512×512。为了扩大感知域并减少计算量，在粗网卷积之前将输入下采样到256×256，而对于细网的输入，则是将输入的hole区域替换为粗网的对应区域。",{"type":18,"tag":26,"props":335,"children":336},{},[337],{"type":24,"value":338},"细网使用高级特征图计算上下文注意力分数，并对多个较低级别特征图执行注意力转移，论文还在粗网和细网中采用扩张卷
积，进一步扩大感受野的大小。此外，为了提高计算效率，LWGC门控卷积应用于生成器的所有层。对于网络卷积层,统一移除了BN（batch normalization）处理，padding处理采用'same'模式，卷积层的激活函数全部采用ELU激活函数。",{"type":18,"tag":51,"props":340,"children":342},{"code":341},"from src.models.network_module import GatedConv2d, TransposeGatedConv2d\nfrom src.models.compute_attention import ContextualAttention, ApplyAttention\n\n\nclass Coarse(nn.Cell):\n    \"\"\"Build the first stage of generator: coarse network\"\"\"\n\n    def __init__(self):\n        super(Coarse, self).__init__()\n        self.coarse1 = nn.SequentialCell(\n            GatedConv2d(4, 32, 5, 2, 1, sc=True),\n            GatedConv2d(32, 32, 3, 1, 1, sc=True),\n            GatedConv2d(32, 64, 3, 2, 1, sc=True)\n        )\n        self.coarse2 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, sc=True)\n        )\n        self.coarse3 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, sc=True)\n        )\n        self.coarse4 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 2, sc=True),\n            GatedConv2d(64, 64, 3, 1, 2, sc=True),\n            GatedConv2d(64, 64, 3, 1, 2, sc=True),\n            GatedConv2d(64, 64, 3, 1, 2, sc=True),\n            GatedConv2d(64, 64, 3, 1, 2, sc=True)\n        )\n        self.coarse5 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 4, sc=True),\n            GatedConv2d(64, 64, 3, 1, 4, sc=True),\n            GatedConv2d(64, 64, 3, 1, 4, sc=True),\n            GatedConv2d(64, 64, 3, 1, 4, sc=True)\n        )\n        self.coarse6 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 8, sc=True),\n            GatedConv2d(64, 64, 3, 1, 8, sc=True),\n        )\n        self.coarse7 = nn.SequentialCell(\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, 
sc=True),\n            GatedConv2d(64, 64, 3, 1, 1, sc=True),\n        )\n        self.coarse8 = nn.SequentialCell(\n            TransposeGatedConv2d(64, 32, 3, 1, 1, sc=True),\n            GatedConv2d(32, 32, 3, 1, 1, sc=True),\n            TransposeGatedConv2d(32, 3, 3, 1, 1, sc=True),\n        )\n\n    def construct(self, first_in):\n        first_out = self.coarse1(first_in)\n        first_out = self.coarse2(first_out)\n        first_out = self.coarse3(first_out)\n        first_out = self.coarse4(first_out)\n        first_out = self.coarse5(first_out)\n        first_out = self.coarse6(first_out)\n        first_out = self.coarse7(first_out)\n        first_out = self.coarse8(first_out)\n        first_out = ops.clip_by_value(first_out, -1, 1)\n        return first_out\n\n\nclass GatedGenerator(nn.Cell):\n    \"\"\"\n    Build the second stage of generator: refine network and complete generator.\n\n    Args:\n        opt(class): option class.\n\n    Return:\n        first_out: The output of coarse network.\n        second_out: The output of refine network.\n        match: Attention score.\n    \"\"\"\n\n    def __init__(self, opt):\n        super(GatedGenerator, self).__init__()\n        self.coarse = Coarse()\n        self.refinement1 = nn.SequentialCell(\n            GatedConv2d(4, 32, 3, 2, 1),\n            GatedConv2d(32, 32, 3, 1, 1)\n        )\n        self.refinement2 = nn.SequentialCell(\n            GatedConv2d(32, 64, 3, 2, 1),\n            GatedConv2d(64, 64, 3, 1, 1)\n        )\n        self.refinement3 = nn.SequentialCell(\n            GatedConv2d(64, 128, 3, 2, 1),\n            GatedConv2d(128, 128, 3, 1, 1)\n        )\n        self.refinement4 = GatedConv2d(128, 128, 3, 1, 1)\n        self.refinement5 = nn.SequentialCell(\n            GatedConv2d(128, 128, 3, 1, 2),\n            GatedConv2d(128, 128, 3, 1, 4)\n        )\n        self.refinement6 = nn.SequentialCell(\n            GatedConv2d(128, 128, 3, 1, 8),\n            GatedConv2d(128, 128, 3, 1, 
16)\n        )\n        self.refinement7 = nn.SequentialCell(\n            TransposeGatedConv2d(128, 64, 3, 1, 1),\n            GatedConv2d(64, 64, 3, 1, 1)\n        )\n        self.refinement8 = nn.SequentialCell(\n            TransposeGatedConv2d(128, 32, 3, 1, 1),\n            GatedConv2d(32, 32, 3, 1, 1)\n        )\n        self.refinement9 = TransposeGatedConv2d(64, 3, 3, 1, 1)\n        self.conv_att1 = GatedConv2d(128, 128, 3, 1, 1)\n        self.conv_att2 = GatedConv2d(256, 128, 3, 1, 1)\n        self.batch = opt.train_batchsize\n        self.apply_attention1 = ApplyAttention([self.batch, 64, 128, 128], [self.batch, 1024, 32, 32])\n        self.apply_attention2 = ApplyAttention([self.batch, 32, 256, 256], [self.batch, 1024, 32, 32])\n        self.ones = ops.Ones()\n        self.concat = ops.Concat(1)\n        self.bilinear_256 = ops.ResizeBilinear((256, 256))\n        self.bilinear_512 = ops.ResizeBilinear((512, 512))\n        self.reshape = ops.Reshape()\n        self.contextual_attention = ContextualAttention(fuse=True, dtype=mindspore.float32)\n        self.cat = ops.Concat(1)\n        self.method = opt.attention_type\n\n    def construct(self, img, mask):\n        x_in = img.astype(mindspore.float32)\n        shape = x_in.shape\n        mask_batch = self.ones((shape[0], 1, shape[2], shape[3]), mindspore.float32)\n        mask_batch = mask_batch * mask\n        first_in = self.concat((x_in, mask_batch))\n        first_in = self.bilinear_256(first_in)\n        first_out = self.coarse(first_in)\n        first_out = self.bilinear_512(first_out)\n        first_out = self.reshape(first_out, (shape[0], shape[1], shape[2], shape[3]))\n        x_coarse = first_out * mask_batch + x_in * (1. 
- mask_batch)\n        second_in = self.concat([x_coarse, mask_batch])\n        pl1 = self.refinement1(second_in)\n        pl2 = self.refinement2(pl1)\n        second_out = self.refinement3(pl2)\n        second_out = self.refinement4(second_out)\n        second_out = self.refinement5(second_out)\n        pl3 = self.refinement6(second_out)\n        x_hallu = pl3\n        x, match = self.contextual_attention(pl3, pl3, mask, self.method)\n        x = self.conv_att1(x)\n        x = self.cat((x_hallu, x))\n        second_out = self.conv_att2(x)\n        second_out = self.refinement7(second_out)\n        second_out_att = self.apply_attention1(pl2, match)\n        second_out = self.concat([second_out_att, second_out])\n        second_out = self.refinement8(second_out)\n        second_out_att = self.apply_attention2(pl1, match)\n        second_out = self.concat([second_out_att, second_out])\n        second_out = self.refinement9(second_out)\n        second_out = ops.clip_by_value(second_out, -1, 1)\n        return first_out, second_out, match\n",[343],{"type":18,"tag":56,"props":344,"children":345},{"__ignoreMap":7},[346],{"type":24,"value":341},{"type":18,"tag":26,"props":348,"children":349},{},[350],{"type":18,"tag":40,"props":351,"children":352},{},[353],{"type":24,"value":354},"判别器",{"type":18,"tag":26,"props":356,"children":357},{},[358],{"type":24,"value":359},"判别器D通过一系列的Conv2d和LeakyReLU层对其进行处理，最后通过nn.Dense函数输出最终判别结果。判别器的代码实现如下：",{"type":18,"tag":51,"props":361,"children":363},{"code":362},"from src.models.network_module import Conv2dLayer\n\n\nclass Discriminator(nn.Cell):\n    \"\"\"Build the complete discriminator\"\"\"\n\n    def __init__(self):\n        super(Discriminator, self).__init__()\n        self.block1 = Conv2dLayer(3, 64, 5, 2, 1)\n        self.block2 = Conv2dLayer(64, 128, 5, 2, 1)\n        self.block3 = Conv2dLayer(128, 256, 5, 2, 1)\n        self.block4 = Conv2dLayer(256, 256, 5, 2, 1)\n        self.block5 = Conv2dLayer(256, 256, 5, 2, 1)\n        
self.block6 = Conv2dLayer(256, 256, 5, 2, 1)\n        self.block7 = nn.Dense(16384, 1)\n\n    def construct(self, img):\n        x = img\n        x = self.block1(x)\n        x = self.block2(x)\n        x = self.block3(x)\n        x = self.block4(x)\n        x = self.block5(x)\n        x = self.block6(x)\n        x = x.reshape([x.shape[0], -1])\n        x = self.block7(x)\n        return x\n",[364],{"type":18,"tag":56,"props":365,"children":366},{"__ignoreMap":7},[367],{"type":24,"value":362},{"type":18,"tag":61,"props":369,"children":371},{"id":370},"构建连接网络和损失函数",[372],{"type":18,"tag":40,"props":373,"children":374},{},[375],{"type":24,"value":370},{"type":18,"tag":26,"props":377,"children":378},{},[379],{"type":24,"value":380},"昇思MindSpore将损失函数、优化器等操作都封装到了Cell中，这种设计给实现GAN带来了一些不方便。因为GAN结构上的特殊性，它和一般的分类网络不同，损失由判别器损失和生成器损失组成，是多输出的。如果直接使用Cell包，框架会不知道Loss和网络是如何连接的，会导致无法训练。所以，我们需要自定义WithLossCell，将网络和Loss连接起来。",{"type":18,"tag":26,"props":382,"children":383},{},[384],{"type":24,"value":385},"对于生成器损失，我们分别构建了对抗性损失Ladv和重建损失Lrec：",{"type":18,"tag":26,"props":387,"children":388},{},[389],{"type":18,"tag":98,"props":390,"children":392},{"alt":100,"src":391},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222063316.47357412358620178793074453824699:50540221072327:2400:D46F4718DD5BC7565172CEDBFE3422FB61FE4211CFDBEF070BCB12326C990D99.png",[],{"type":18,"tag":26,"props":394,"children":395},{},[396],{"type":24,"value":397},"其中，α 一般取1.2，α1 取1.2，α2 取1.2，β 取0.001。",{"type":18,"tag":51,"props":399,"children":401},{"code":400},"from src.models.cra_utils.utils import gan_wgan_loss\n\n\nclass GenWithLossCell(nn.Cell):\n    \"\"\"\n    Build the generator loss.\n\n    Args:\n        net_g(cell): generator network.\n        net_d(cell): discriminator network.\n        args(class): option class.\n        auto_prefix(bool): whether to automatically generate namespace for cell and its subcells.\n            If set to 
True, the network parameter name will be prefixed, otherwise it will not.\n\n    Return:\n        loss_g: the loss of generator.\n    \"\"\"\n\n    def __init__(self, net_g, net_d, args, auto_prefix=True):\n        super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix)\n        self.net_g = net_g\n        self.net_d = net_d\n        self.gan_wgan_loss = gan_wgan_loss\n        self.coarse_alpha = args.coarse_alpha\n        self.gan_with_mask = args.gan_with_mask\n        self.gan_loss_alpha = args.gan_loss_alpha\n        self.in_hole_alpha = args.in_hole_alpha\n        self.context_alpha = args.context_alpha\n        self.train_batchsize = args.train_batchsize\n        self.mean = ops.ReduceMean(False)\n        self.abs = ops.Abs()\n        self.concat_0 = ops.Concat(0)\n        self.concat_1 = ops.Concat(1)\n        self.split = ops.Split(0, 2)\n        self.tile = ops.Tile()\n\n    def construct(self, real, x, mask):\n        x1, x2, _ = self.net_g(x, mask)\n        fake = x2\n        losses = {}\n        fake_patched = fake * mask + real * (1 - mask)\n        fake_patched = fake_patched.astype(mindspore.float32)\n        losses['in_hole_loss'] = self.coarse_alpha * self.mean(self.abs(real - x1) * mask)\n        losses['in_hole_loss'] = losses['in_hole_loss'] + self.mean(self.abs(real - x2) * mask)\n        losses['context_loss'] = self.coarse_alpha * self.mean(self.abs(real - x1) * (1 - mask))\n        losses['context_loss'] = losses['context_loss'] + self.mean(self.abs(real - x2) * (1 - mask))\n        losses['context_loss'] = losses['context_loss'] / self.mean(1 - mask)\n        real_fake = self.concat_0((real, fake_patched))\n        if self.gan_with_mask:\n            real_fake = self.concat_1((real_fake, self.tile(mask, (self.train_batchsize * 2, 1, 1, 1))))\n        d_real_fake = self.net_d(real_fake)\n        d_real, d_fake = self.split(d_real_fake)\n        g_loss, _ = self.gan_wgan_loss(d_real, d_fake)\n        losses['adv_gloss'] = g_loss\n      
  losses['g_loss'] = self.gan_loss_alpha * losses['adv_gloss']\n        losses['g_loss'] = losses['g_loss'] + self.in_hole_alpha * losses['in_hole_loss']\n        losses['g_loss'] = losses['g_loss'] + self.context_alpha * losses['context_loss']\n        loss_g = losses['g_loss']\n        return loss_g\n",[402],{"type":18,"tag":56,"props":403,"children":404},{"__ignoreMap":7},[405],{"type":24,"value":400},{"type":18,"tag":26,"props":407,"children":408},{},[409],{"type":24,"value":410},"对于判别器损失，我们加入了WGAN-GP loss，加强了第二阶段细化网络的全局一致性：",{"type":18,"tag":26,"props":412,"children":413},{},[414],{"type":18,"tag":98,"props":415,"children":418},{"alt":416,"src":417},"cke_170436.png","https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222063410.39269388701715290722225722562141:50540221072327:2400:319EC56D14F68D6A68DBE5FE3018686819182ECEA95C8E3BE36E7A862BAC5ED7.png",[],{"type":18,"tag":51,"props":420,"children":422},{"code":421},"from src.models.cra_utils.utils import random_interpolates, GradientsPenalty\n\n\nclass DisWithLossCell(nn.Cell):\n    \"\"\"\n    Build the discriminator loss.\n\n    Args:\n        net_g(cell): generator network.\n        net_d(cell): discriminator network.\n        args(class): option class.\n        auto_prefix(bool): whether to automatically generate namespace for cell and its subcells.\n            If set to True, the network parameter name will be prefixed, otherwise it will not.\n\n    Return:\n        loss_d: the loss of discriminator.\n    \"\"\"\n\n    def __init__(self, net_g, net_d, args, auto_prefix=True):\n        super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix)\n        self.net_g = net_g\n        self.net_d = net_d\n        self.gan_wgan_loss = gan_wgan_loss\n        self.random_interpolates = random_interpolates\n        self.gradients_penalty = GradientsPenalty(self.net_d)\n        self.gan_with_mask = args.gan_with_mask\n        self.wgan_gp_lambda = 
args.wgan_gp_lambda\n        self.train_batchsize = args.train_batchsize\n        self.concat_0 = ops.Concat(0)\n        self.concat_1 = ops.Concat(1)\n        self.split = ops.Split(0, 2)\n\n    def construct(self, real, x, mask):\n        _, x2, _ = self.net_g(x, mask)\n        fake = x2\n        losses = {}\n        fake_patched = fake * mask + real * (1 - mask)\n        fake_patched = fake_patched.astype(mindspore.float32)\n        real_fake = self.concat_0((real, fake_patched))\n        if self.gan_with_mask:\n            real_fake = self.concat_1((real_fake, ops.Tile()(mask, (self.train_batchsize * 2, 1, 1, 1))))\n        d_real_fake = self.net_d(real_fake)\n        d_real, d_fake = self.split(d_real_fake)\n        _, d_loss = self.gan_wgan_loss(d_real, d_fake)\n        losses['adv_dloss'] = d_loss\n        interps = self.random_interpolates(real, fake_patched)\n        gp_loss = self.gradients_penalty(interps)\n        losses['gp_loss'] = self.wgan_gp_lambda * gp_loss\n        losses['d_loss'] = losses['adv_dloss'] + losses['gp_loss']\n        loss_d = losses['d_loss']\n        return loss_d\n",[423],{"type":18,"tag":56,"props":424,"children":425},{"__ignoreMap":7},[426],{"type":24,"value":421},{"type":18,"tag":26,"props":428,"children":429},{},[430],{"type":24,"value":431},"搭建损失函数与网络的连接，定义训练网络封装类：",{"type":18,"tag":51,"props":433,"children":435},{"code":434},"import mindspore.ops.functional as F\nfrom mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, _get_parallel_mode)\nfrom mindspore.context import ParallelMode\nfrom mindspore.nn.wrap.grad_reducer import DistributedGradReducer\n\n\nclass TrainOneStepD(nn.Cell):\n    \"\"\"Encapsulation class of discriminator network training.\"\"\"\n\n    def __init__(self, d, optimizer, sens=1.0):\n        super(TrainOneStepD, self).__init__(auto_prefix=True)\n        self.optimizer = optimizer\n        self.d = d\n        self.d.net_d.set_grad()\n        self.d.net_d.set_train()\n        
self.d.net_g.set_grad(False)\n        self.d.net_g.set_train(False)\n        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)\n        self.sens = sens\n        self.weights = optimizer.parameters\n        self.reducer_flag = False\n        self.fill = ops.Fill()\n        self.dtype = ops.DType()\n        self.shape = ops.Shape()\n        self.grad_reducer = F.identity\n        self.parallel_mode = _get_parallel_mode()\n        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):\n            self.reducer_flag = True\n        if self.reducer_flag:\n            mean = _get_gradients_mean()\n            degree = _get_device_num()\n            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)\n\n    def construct(self, real, x, mask):\n        weights = self.weights\n        loss_d = self.d(real, x, mask)\n        sens_d = self.fill(self.dtype(loss_d), self.shape(loss_d), self.sens)\n        grads_d = self.grad(self.d, weights)(real, x, mask, sens_d)\n        if self.reducer_flag:\n            grads_d = self.grad_reducer(grads_d)\n        self.optimizer(grads_d)\n        return loss_d\n\n\nclass TrainOneStepG(nn.Cell):\n    \"\"\"Encapsulation class of generator network training.\"\"\"\n\n    def __init__(self, g, optimizer, sens=1.0):\n        super(TrainOneStepG, self).__init__(auto_prefix=True)\n        self.optimizer = optimizer\n        self.g = g\n        self.g.net_g.set_grad()\n        self.g.net_g.set_train()\n        self.g.net_d.set_grad(False)\n        self.g.net_d.set_train(False)\n        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)\n        self.sens = sens\n        self.weights = optimizer.parameters\n        self.reducer_flag = False\n        self.fill = ops.Fill()\n        self.dtype = ops.DType()\n        self.shape = ops.Shape()\n        self.grad_reducer = F.identity\n        self.parallel_mode = _get_parallel_mode()\n        if self.parallel_mode in 
(ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):\n            self.reducer_flag = True\n        if self.reducer_flag:\n            mean = _get_gradients_mean()\n            degree = _get_device_num()\n            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)\n\n    def construct(self, real, x, mask):\n        weights = self.weights\n        loss_g = self.g(real, x, mask)\n        sens_g = self.fill(self.dtype(loss_g), self.shape(loss_g), self.sens)\n        grads_g = self.grad(self.g, weights)(real, x, mask, sens_g)\n        if self.reducer_flag:\n            grads_g = self.grad_reducer(grads_g)\n        self.optimizer(grads_g)\n        return loss_g\n",[436],{"type":18,"tag":56,"props":437,"children":438},{"__ignoreMap":7},[439],{"type":24,"value":434},{"type":18,"tag":61,"props":441,"children":443},{"id":442},"构建优化器",[444],{"type":18,"tag":40,"props":445,"children":446},{},[447],{"type":24,"value":442},{"type":18,"tag":51,"props":449,"children":451},{"code":450},"net_g = GatedGenerator(config)\nnet_d = Discriminator()\nlr = nn.exponential_decay_lr(config.learning_rate, config.lr_decrease_factor, total_batch * config.epochs, total_batch,\n                             config.lr_decrease_epoch, True)\noptimizer_g = nn.Adam(filter(lambda p: p.requires_grad, net_g.trainable_params()), lr, 0.5, 0.9)\noptimizer_d = nn.Adam(net_d.trainable_params(), lr, 0.5, 0.9)\n",[452],{"type":18,"tag":56,"props":453,"children":454},{"__ignoreMap":7},[455],{"type":24,"value":450},{"type":18,"tag":26,"props":457,"children":458},{},[459],{"type":24,"value":460},"这里我们设置了两个单独的优化器，分别用于判别器和生成器，参数设定统一 beta1=0.5, 
beta2=0.9，学习率采用指数衰减函数自动更新。",{"type":18,"tag":61,"props":462,"children":464},{"id":463},"训练模型",[465],{"type":18,"tag":40,"props":466,"children":467},{},[468],{"type":24,"value":463},{"type":18,"tag":26,"props":470,"children":471},{},[472],{"type":24,"value":473},"训练分为两个主要部分：训练判别器和训练生成器。训练判别器是为了能够更好地识别真伪，尽量把生成器生成的图片和真实的图片区分开来；训练生成器是为了生成尽可能接近真实的虚假图片。",{"type":18,"tag":26,"props":475,"children":476},{},[477],{"type":24,"value":478},"训练过程：",{"type":18,"tag":51,"props":480,"children":482},{"code":481},"import cv2\nimport time\n\nfrom mindspore import context, save_checkpoint, nn\n\nfrom src.config.config import cra_config\nfrom src.models.inpainting_network import GatedGenerator, Discriminator\nfrom src.models.loss import GenWithLossCell, DisWithLossCell\nfrom src.models.train_one_step import TrainOneStepD, TrainOneStepG\n\n\ndef trainer(args):\n    \"\"\"Train model.\"\"\"\n\n    # Preprocess the data for training\n    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')\n    dataset_generator = InpaintDataset(args)\n    dataset_size = len(dataset_generator)\n    total_batch = dataset_size // args.train_batchsize\n    dataset = ds.GeneratorDataset(dataset_generator, ['image'])\n    dataset = dataset.batch(args.train_batchsize, drop_remainder=True)\n    dataset = dataset.create_dict_iterator()\n\n    # Network\n    net_g = GatedGenerator(args)\n    net_d = Discriminator()\n    netg_with_loss = GenWithLossCell(net_g, net_d, args)\n    netd_with_loss = DisWithLossCell(net_g, net_d, args)\n    lr = nn.exponential_decay_lr(args.learning_rate, args.lr_decrease_factor, total_batch * 10, total_batch,\n                                 args.lr_decrease_epoch, True)\n    optimizer_g = nn.Adam(filter(lambda p: p.requires_grad, net_g.trainable_params()), lr, 0.5, 0.9)\n    optimizer_d = nn.Adam(net_d.trainable_params(), lr, 0.5, 0.9)\n    train_discriminator = TrainOneStepD(netd_with_loss, optimizer_d)\n    train_generator = TrainOneStepG(netg_with_loss, 
optimizer_g)\n\n    # Train\n    train_discriminator.set_train()\n    train_generator.set_train()\n    print(\"Starting Training Loop...\")\n    for epoch in range(10):\n        for batch_idx, image in enumerate(dataset):\n            s = time.time()\n            real = image['image']\n            real = real.astype(mindspore.float32)\n            mask, _ = random_mask(args)\n            x = real * (1 - mask)\n            for _ in range(args.dis_iter):\n                netd_loss = train_discriminator(real, x, mask)\n            netg_loss = train_generator(real, x, mask)\n            gap = time.time() - s\n            # Print losses\n            print('epoch{}/{}, batch{}/{}, d_loss is {:.4f}, g_loss is {:.4f}, time is {:.4f}'.format(\n                epoch + 1, args.epochs, batch_idx + 1, total_batch, netd_loss.asnumpy(), netg_loss.asnumpy(), gap))\n            save_checkpoint_path = './ckpt_out'\n            if not os.path.isdir(save_checkpoint_path):\n                os.makedirs(save_checkpoint_path)\n            # Save checkpoint\n            gen_name = 'generator_epoch%d_batch%d.ckpt' % (epoch + 1, batch_idx + 1)\n            dis_name = 'discriminator_epoch%d_batch%d.ckpt' % (epoch + 1, batch_idx + 1)\n            gen_name = os.path.join(save_checkpoint_path, gen_name)\n            dis_name = os.path.join(save_checkpoint_path, dis_name)\n            if (batch_idx + 1) == total_batch:\n                save_checkpoint(train_generator, gen_name)\n                save_checkpoint(train_discriminator, 
dis_name)\ntrainer(cra_config)\n",[483],{"type":18,"tag":56,"props":484,"children":485},{"__ignoreMap":7},[486],{"type":24,"value":481},{"type":18,"tag":61,"props":488,"children":490},{"id":489},"模型推理",[491],{"type":18,"tag":40,"props":492,"children":493},{},[494],{"type":24,"value":489},{"type":18,"tag":26,"props":496,"children":497},{},[498],{"type":24,"value":499},"在完成生成对抗网络训练后，我们可以使用GAN网络来预测低分辨率图像的修复结果，但要生成一张完整的高分辨率的修复图像，我们还需要做一些后处理操作，具体为：获取图像上下文残差信息；通过高频残差和注意力机制来生成缺失内容的聚合残差；对GAN网络生成图像进行上采样；将聚合残差添加到大而模糊的生成图像中获得清晰修复图像；将修复图像处理到与原始待修复图像同样尺寸。",{"type":18,"tag":51,"props":501,"children":503},{"code":502},"import glob\nimport cv2\nimport numpy as np\n\n\ndef sort(str_lst):\n    \"\"\"Return the sorted list in ascending order.\"\"\"\n\n    return [s for s in sorted(str_lst)]\n\n\ndef read_imgs_masks(args):\n    \"\"\"Sort the image and mask directories in order and return it.\"\"\"\n\n    paths_img = glob.glob(args.image_dir + '/*.*[g|G]')\n    paths_img = sort(paths_img)\n    paths_mask = glob.glob(args.mask_dir + '/*.*[g|G]')\n    paths_mask = sort(paths_mask)\n    return paths_img, paths_mask\n\n\ndef get_input(path_img, path_mask):\n    \"\"\"Read and process the image and mask through the given path.\"\"\"\n\n    image = cv2.imread(path_img)\n    mask = cv2.imread(path_mask)\n    image = np.expand_dims(image, 0)\n    mask = np.expand_dims(mask, 0)\n    return image[0], mask[0]\nfrom mindspore import nn, ops\n\nfrom src.models.inpainting_network import GatedGenerator\nfrom src.models.compute_attention import ApplyAttention2\n\n\ndef post_processing(large_img, small_img, low_base, small_mask, corres, args):\n    \"\"\"Subtracting the large blurry image from the raw input to compute contextual residuals,\n     and calculate aggregated residuals through attention transfer module.\n     Adding the aggregated residuals to the up-sampled generator inpainted result.\"\"\"\n\n    high_raw = large_img\n    low_raw = small_img\n    mask = 1 - small_mask\n    low_raw = 
nn.ResizeBilinear()(low_raw, scale_factor=args.times)\n    to_shape = list(ops.Shape()(mask))[2:]\n    to_shape[0], to_shape[1] = int(to_shape[0] * args.times), int(to_shape[1] * args.times)\n    resize = ops.ResizeNearestNeighbor((to_shape[0], to_shape[1]))\n    mask = resize(mask)\n    residual1 = (high_raw - low_raw) * mask\n    residual = ApplyAttention2([1, 3, 4096, 4096], [1, 1024, 32, 32])(residual1, corres)\n    low_base = nn.ResizeBilinear()(low_base, scale_factor=args.times)\n    x = low_base + residual\n    x = x.clip(-1, 1)\n    x = (x + 1.) * 127.5\n    return x, low_raw, low_base, residual\nfrom scipy import signal\n\nimport mindspore\nfrom mindspore import Tensor\n\n\ndef gaussian_kernel(size, std):\n    \"\"\"Return a gaussian kernel.\"\"\"\n\n    k = signal.gaussian(size, std)\n    kk = np.matmul(k[:, np.newaxis], [k])\n    return kk / np.sum(kk)\n\n\ndef resize_back(raw_img, large_output, small_mask):\n    \"\"\"Process the test output result in the format of [1, 3,4096,4096] to the same size as the original input image.\"\"\"\n\n    raw_shp = raw_img.shape\n    raw_size_output = nn.ResizeBilinear()(large_output, size=(raw_shp[2], raw_shp[3]))\n    raw_size_output = raw_size_output.astype(mindspore.float32)\n    gauss_kernel = gaussian_kernel(7, 1.)\n    gauss_kernel = Tensor(gauss_kernel)\n    gauss_kernel = gauss_kernel.astype(mindspore.float32)\n    gauss_kernel = ops.ExpandDims()(gauss_kernel, 2)\n    gauss_kernel = ops.ExpandDims()(gauss_kernel, 3)\n    a, b, c, d = ops.Shape()(gauss_kernel)\n    gauss_kernel = ops.Transpose()(gauss_kernel, (3, 2, 0, 1))\n    conv = nn.Conv2d(c, d, (a, b), 1, pad_mode='same', padding=0, weight_init=gauss_kernel, data_format='NCHW')\n    mask = conv(small_mask[:, 0:1, :, :])\n    mask = nn.ResizeBilinear()(mask, size=(raw_shp[2], raw_shp[3]))\n    mask = mask.astype(mindspore.float32)\n    raw_size_output = raw_size_output * mask + raw_img * (1 - mask)\n    raw_size_output = ops.Transpose()(raw_size_output, 
(0, 2, 3, 1))\n    raw_size_output = raw_size_output.astype(mindspore.uint8)\n    return raw_size_output\ndef build_inference_graph(real, mask, model_gen):\n    \"\"\"Input real and mask to generator and output the results.\"\"\"\n\n    mask = mask[0:1, 0:1, :, :]\n    x = real * (1. - mask)\n    _, x2, corres = model_gen(x, mask)\n    fake = x2\n    fake_patched = fake * mask + x * (1 - mask)\n    return x2, fake_patched, corres\n\n\ndef build_inference_net(raw_img_ph, raw_mask_ph, model_gen, args):\n    \"\"\"\n    Complete CRA network testing model, including image preprocessing, generator generation and output,\n        and image post-processing operations.\n\n    Args:\n        raw_img_ph(Tensor): image read from folder.\n            It is processed into the format of [1,3,512,512], the data type is float32, and normalized.\n        raw_mask_ph(Tensor): mask read from folder.\n            It is processed into the format of [1,3,512,512], the data type is float32, and normalized.\n        model_gen(cell): generation network.\n        args(class): option class.\n\n    Return:\n        raw_size_output: Large test output results.\n        raw_img_ph: Image read from folder.\n        raw_mask_ph: Mask read from folder.\n    \"\"\"\n\n    # Process input image\n    raw_img = ops.ExpandDims()(raw_img_ph, 0)\n    raw_img = raw_img.astype(mindspore.float32)\n    raw_img = ops.Transpose()(raw_img, (0, 3, 1, 2))\n    resize = ops.ResizeNearestNeighbor((args.times * args.input_size, args.times * args.input_size))\n    large_img = resize(raw_img)\n    large_img = ops.Reshape()(large_img, (1, 3, args.times * args.input_size, args.times * args.input_size))\n    large_img = large_img / 127.5 - 1\n    net = nn.Unfold([1, args.times, args.times, 1], [1, args.times, args.times, 1], [1, 1, 1, 1], 'same')\n    small_img = net(large_img)\n    small_img = ops.Transpose()(small_img, (0, 2, 3, 1))\n    small_img = ops.Reshape()(small_img, (1, args.input_size, args.input_size, 
args.times, args.times, 3))\n    small_img = ops.ReduceMean(False)(small_img, axis=(3, 4))\n    small_img = ops.Transpose()(small_img, (0, 3, 1, 2))\n    # Process input mask\n    raw_mask = ops.ExpandDims()(raw_mask_ph, 0)\n    raw_mask = raw_mask.astype(mindspore.float32)\n    raw_mask = ops.Transpose()(raw_mask, (0, 3, 1, 2))\n    resize = ops.ResizeNearestNeighbor((args.input_size, args.input_size))\n    small_mask = resize(raw_mask)\n    small_mask = ops.Reshape()(small_mask, (1, 3, args.input_size, args.input_size))\n    small_mask = 1 - small_mask / 255\n    # Input image and mask to genenrator\n    x2, _, corres = build_inference_graph(real=small_img, mask=small_mask, model_gen=model_gen)\n    # Post processing\n    large_output, _, _, _ = post_processing(large_img, small_img, x2, small_mask, corres, args)\n    # Resize back\n    raw_size_output = resize_back(raw_img, large_output, small_mask)\n    return raw_size_output, raw_img_ph, raw_mask_ph\n",[504],{"type":18,"tag":56,"props":505,"children":506},{"__ignoreMap":7},[507],{"type":24,"value":502},{"type":18,"tag":26,"props":509,"children":510},{},[511],{"type":24,"value":512},"推理代码如下：",{"type":18,"tag":51,"props":514,"children":516},{"code":515},"import os\nimport time\nimport argparse\nimport progressbar\n\nfrom mindspore import context, load_checkpoint, load_param_into_net\n\n\ndef parse_args():\n    \"\"\"Parse parameters.\"\"\"\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_dir', default='./test/images', type=str, help='The directory of images to be tested.')\n    parser.add_argument('--mask_dir', default='./test/masks', type=str, help='The directory of masks.')\n    parser.add_argument('--output_dir', default='./output', type=str, help='Where to write testing output.')\n    parser.add_argument('--checkpoint_dir', default='./ckpt_out/generator_epoch10_batch4.ckpt', type=str,\n                        help='The directory of loading checkpoint.')\n    
parser.add_argument('--attention_type', default='SOFT', type=str, help='compute attention type.')\n    parser.add_argument('--train_batchsize', default=1, type=int, help='Batch size for testing.')\n    parser.add_argument('--input_size', default=512, type=int, help='The image size of the input network in the test.')\n    parser.add_argument('--times', default=8, type=int, help='The scaling size of input image.')\n    return parser.parse_args(args=[])\n\n\n# setting test data\ncra_config = parse_args()\nimg_paths, mask_paths = read_imgs_masks(cra_config)\nif not os.path.exists(cra_config.output_dir):\n    os.makedirs(cra_config.output_dir)\ntotal_time = 0\nbar = progressbar.ProgressBar(maxval=len(img_paths), widgets=[progressbar.Bar('=', '[', ']'), ' ',\n                                                              progressbar.Percentage()])\nbar.start()\n# load net and checkpoint file\ngen = GatedGenerator(cra_config)\nparam_dict = load_checkpoint(cra_config.checkpoint_dir)\nload_param_into_net(gen, param_dict)\n#test\nfor (i, img_path) in enumerate(img_paths):\n    rint = i % len(mask_paths)\n    bar.update(i + 1)\n    img_test, mask_test = get_input(img_path, mask_paths[rint])\n    s = time.time()\n    input_img_ph = Tensor(img_test)\n    input_mask_ph = Tensor(255 - mask_test)\n    outputs, input_img_ph, input_mask_ph = build_inference_net(input_img_ph, input_mask_ph, gen, cra_config)\n    res = outputs[0]\n    res = res.asnumpy()\n    total_time += time.time() - s\n    img_hole = img_test * (1 - mask_test / 255) + mask_test\n    res = np.concatenate([img_test, img_hole, res], axis=1)\n    cv2.imwrite(cra_config.output_dir + '/' + str(i) + '.jpg', res)\n    print('test finish')\nbar.finish()\nprint('average time per image', total_time / 
len(img_paths))\n",[517],{"type":18,"tag":56,"props":518,"children":519},{"__ignoreMap":7},[520],{"type":24,"value":515},{"type":18,"tag":26,"props":522,"children":523},{},[524],{"type":18,"tag":98,"props":525,"children":527},{"alt":7,"src":526},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2023/02/22/959c963c389c4aa28810fc7aa57d0a96.png",[],{"type":18,"tag":171,"props":529,"children":531},{"id":530},"引用",[532],{"type":18,"tag":40,"props":533,"children":534},{},[535],{"type":24,"value":530},{"type":18,"tag":26,"props":537,"children":538},{},[539],{"type":24,"value":540},"[1] Z. Yi, Q. Tang, S. Azizi, D. Jang and Z. Xu. Contextual Residual Aggregation for Ultra High-Resolution Image Inpainting[J]. 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 7505-7514.",{"type":18,"tag":26,"props":542,"children":543},{},[544,546],{"type":24,"value":545},"更多昇思MindSpore应用案例请访问官网开发者案例：",{"type":18,"tag":547,"props":548,"children":552},"a",{"href":549,"rel":550},"https://www.mindspore.cn/resources/cases",[551],"nofollow",[553],{"type":24,"value":549},{"type":18,"tag":26,"props":555,"children":556},{},[557],{"type":18,"tag":98,"props":558,"children":560},{"alt":100,"src":559},"https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/e64/154/b38/90a1d5d431e64154b387b3660e356ff5.20230222063857.14917177604882183380155878096812:50540221072327:2400:81569A38454CBBB312D5C66B1394485EBD654D9754A9346E339B4FD1F53A023D.png",[],{"title":7,"searchDepth":562,"depth":562,"links":563},4,[564,566,567,578],{"id":63,"depth":565,"text":63},3,{"id":105,"depth":565,"text":105},{"id":173,"depth":568,"text":179,"children":569},2,[570,571,572,573,574,575,576,577],{"id":229,"depth":562,"text":235},{"id":295,"depth":565,"text":301},{"id":304,"depth":565,"text":307},{"id":310,"depth":565,"text":313},{"id":370,"depth":565,"text":370},{"id":442,"depth":565,"text":442},{"id":463,"depth":565,"text":463},{"id":489,"depth":565,"text":489},{"id":530
,"depth":568,"text":530},"markdown","content:technology-blogs:zh:2173.md","content","technology-blogs/zh/2173.md","technology-blogs/zh/2173","md",1776506120511]