MindSpore为你添彩-自动着色算法之Colorization实践
MindSpore为你添彩-自动着色算法之Colorization实践
自动着色算法之Colorization
模型简介
Colorization算法是来自加里福利亚大学的一项研究,采用的是CNN的结构。该算法可以实现灰度图像的自动着色,由Richard Zhang等人在论文Colorful Image Colorization中提出,并发表在2016年的ECCV会议中。该模型由8个conv层组成,每个conv层由2个或3个重复的卷积层和ReLU层组成,后面跟着一个BatchNorm层。网络中不包含池化层。
网络特点
设计了一个合适的损失函数来处理着色问题中的多模不确定性,维持了颜色的多样性。
将图像着色任务转化为一个自监督表达学习的任务。
在一些基准模型上获得了最好的效果。
encode_layer = NNEncLayer(opt)
boost_layer = PriorBoostLayer(opt)
non_gray_mask = NonGrayMaskLayer()
net = ColorizationModel()
net_opt = nn.Adam(net.trainable_params(), learning_rate=opt.learning_rate)
net_with_criterion = NetLoss(net)
scale_sense = nn.FixedLossScaleUpdateCell(1)
my_train_one_step_cell_for_net = nn.TrainOneStepWithLossScaleCell(net_with_criterion, net_opt,
scale_sense=scale_sense)
colormodel = ColorModel(my_train_one_step_cell_for_net)
colormodel.set_train()
算法的主要原理是,将一张LAB格式灰度图片的L通道,输入模型进行推理,推理出其AB通道,最后将原始的L通道和推理出的AB通道结合起来,得到一张上色的图片。

一般常见的图片格式是RGB,有三个通道分别表示红色、绿色、蓝色。三个颜色组合出各种不同的颜色。而LAB图片格式的L通道表示图像的亮度,取值范围为0到100,颜色越大表示颜色越亮。AB的值域都是从-128到+128,A代表从绿色到红色的分量,B代表从蓝色到黄色的分量。

数据准备
本案例使用ImageNet数据集作为训练集和测试集。可以在官网下载。训练集中包含1000个类别,总计大约120万张图片,测试集中包含5万图片。
152G 建议通过bt下载。

下载完成后需要解压缩。

训练集可视化
import os
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import mindspore
from src.process_datasets.data_generator import ColorizationDataset
#加载参数
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--shuffle', type=bool, default=True)
args = parser.parse_args(args=[])
plt.figure()
#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers)
data = dataset.run()
show_data = next(data.create_tuple_iterator())
show_images_original, _ = show_data
show_images_original = show_images_original.asnumpy()
#循环处理
for i in range(1, 5):
plt.subplot(1, 4, i)
temp = show_images_original[i-1]
temp = np.clip(temp, 0, 1)
plt.imshow(temp)
plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0)
会提示缺少依赖


可视化的结果如下:代码中缺少plt.show(),需要加上。


接下来是训练模型,在src目录下的train.py

同样会报缺少依赖

补齐依赖后继续

再次运行仍然报错,
查看代码,确定device_id默认值为1,也就是第二张卡,由于只有一张GPU卡,所以会报错,建议这里默认值修改成0

修改后继续

GPU的12G的显存,batch size选128 显存不够用了。降低到64试试。

查看显卡状态,功率200w 显存使用接近12G.

跑了大概12个小时,还是没跑完。看了单卡还是太慢了。


保存的模型文件都有22G了

停止训练,选取最新的权重进行推理。

直接使用infer.py, 看了数据集读取有点问题。
下载的数据集val下面是没有文件夹的,增加一个文件夹。

同样会遇到device id的问题,默认值修改为0.

最终着色的结果,中间的为mindspore预训练模型推理的结果,最右侧为本次训练模型的推理结果。


==,好像哪不对,输入应该是灰度图片。要先处理下原图。


我觉得红色的青苹果也不是不可以。。。
主要是我训练时间太短,模型准确度和预训练模型还是有点差距。