代码
MindSpore为你添彩-自动着色算法之Colorization实践

MindSpore为你添彩-自动着色算法之Colorization实践

MindSpore为你添彩-自动着色算法之Colorization实践

自动着色算法之Colorization

工程代码地址

模型简介

Colorization算法是来自加里福利亚大学的一项研究,采用的是CNN的结构。该算法可以实现灰度图像的自动着色,由Richard Zhang等人在论文Colorful Image Colorization中提出,并发表在2016年的ECCV会议中。该模型由8个conv层组成,每个conv层由2个或3个重复的卷积层和ReLU层组成,后面跟着一个BatchNorm层。网络中不包含池化层。

网络特点

设计了一个合适的损失函数来处理着色问题中的多模不确定性,维持了颜色的多样性。

将图像着色任务转化为一个自监督表达学习的任务。

在一些基准模型上获得了最好的效果。

    encode_layer = NNEncLayer(opt) 

    boost_layer = PriorBoostLayer(opt) 

    non_gray_mask = NonGrayMaskLayer() 

    net = ColorizationModel() 

    net_opt = nn.Adam(net.trainable_params(), learning_rate=opt.learning_rate) 

    net_with_criterion = NetLoss(net) 

    scale_sense = nn.FixedLossScaleUpdateCell(1) 

    my_train_one_step_cell_for_net = nn.TrainOneStepWithLossScaleCell(net_with_criterion, net_opt, 

                                                                      scale_sense=scale_sense) 

    colormodel = ColorModel(my_train_one_step_cell_for_net) 

    colormodel.set_train() 

算法的主要原理是,将一张LAB格式灰度图片的L通道,输入模型进行推理,推理出其AB通道,最后将原始的L通道和推理出的AB通道结合起来,得到一张上色的图片。

cke_8751.png

一般常见的图片格式是RGB,有三个通道分别表示红色、绿色、蓝色。三个颜色组合出各种不同的颜色。而LAB图片格式的L通道表示图像的亮度,取值范围为0到100,颜色越大表示颜色越亮。AB的值域都是从-128到+128,A代表从绿色到红色的分量,B代表从蓝色到黄色的分量。

cke_10600.png

数据准备

本案例使用ImageNet数据集作为训练集和测试集。可以在官网下载。训练集中包含1000个类别,总计大约120万张图片,测试集中包含5万图片。

152G 建议通过bt下载。

cke_14286.png

下载完成后需要解压缩。

cke_16433.png

训练集可视化

import os 

import argparse 

from tqdm import tqdm 

import numpy as np 

import matplotlib.pyplot as plt 

import mindspore 

from src.process_datasets.data_generator import ColorizationDataset 

  

  

#加载参数 

parser = argparse.ArgumentParser() 

parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset') 

parser.add_argument('--batch_size', type=int, default=4) 

parser.add_argument('--num_parallel_workers', type=int, default=1) 

parser.add_argument('--shuffle', type=bool, default=True) 

args = parser.parse_args(args=[]) 

plt.figure() 

  

#加载数据集 

dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers) 

data = dataset.run() 

show_data = next(data.create_tuple_iterator()) 

show_images_original, _ = show_data 

show_images_original = show_images_original.asnumpy() 

#循环处理 

for i in range(1, 5): 

    plt.subplot(1, 4, i) 

    temp = show_images_original[i-1] 

    temp = np.clip(temp, 0, 1) 

    plt.imshow(temp) 

    plt.axis("off") 

    plt.subplots_adjust(wspace=0.05, hspace=0) 

会提示缺少依赖

cke_27705.png

cke_29875.png

可视化的结果如下:代码中缺少plt.show(),需要加上。

cke_32049.png

cke_34325.png

接下来是训练模型,在src目录下的train.py

cke_36654.png

同样会报缺少依赖

cke_40585.png

补齐依赖后继续

cke_43072.png

再次运行仍然报错,

查看代码,确定device_id默认值为1,也就是第二张卡,由于只有一张GPU卡,所以会报错,建议这里默认值修改成0

cke_52903.png

修改后继续

cke_55595.png

GPU的12G的显存,batch size选128 显存不够用了。降低到64试试。

cke_58231.png

查看显卡状态,功率200w 显存使用接近12G.

cke_60976.png

跑了大概12个小时,还是没跑完。看了单卡还是太慢了。

cke_66751.png

cke_63828.png

保存的模型文件都有22G了

cke_69744.png

停止训练,选取最新的权重进行推理。

cke_73046.png

直接使用infer.py, 看了数据集读取有点问题。

下载的数据集val下面是没有文件夹的,增加一个文件夹。

cke_80278.png

同样会遇到device id的问题,默认值修改为0.

cke_76434.png

最终着色的结果,中间的为mindspore预训练模型推理的结果,最右侧为本次训练模型的推理结果。

cke_83561.png

cke_87048.png

==,好像哪不对,输入应该是灰度图片。要先处理下原图。

cke_90547.png

cke_94095.png

我觉得红色的青苹果也不是不可以。。。

主要是我训练时间太短,模型准确度和预训练模型还是有点差距。