昇思应用案例 | Colorization实现灰度图像的自动着色
昇思应用案例 | Colorization实现灰度图像的自动着色
作者:Yeats_Liao |来源:CSDN
本案例对Colorful Image Colorization文中提出的模型进行了详细的解释,向读者完整地展现了该算法的流程,分析了Colorization在着色方面的优势和存在的不足。如需查看详细代码,可参考MindSpore Vision套件。
01
环节准备
进入昇思MindSpore官网(https://www.mindspore.cn/),点击上方的安装。

打开一个Terminal,输入安装命令。
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
再点击侧边栏中的Clone a Repository,输入代码。
https://github.com/mindspore-courses/applications.git

02
自动着色算法之Colorization****简介
当桃乐丝在1939年的电影《绿野仙踪》中走进奥兹国时,从黑白到鲜艳的色彩的转变使它成为电影史上最令人叹为观止的时刻之一。毫无疑问,颜色是一种有效的表达工具,但它们通常是有代价的。在制作现代动画电影和漫画时,图像着色是最费力和昂贵的阶段之一。自动着色过程可以帮助减少制作漫画或动画电影所需的成本和时间。
Colorization算法是来自加里福利亚大学的一项研究,采用的是CNN的结构。该算法可以实现灰度图像的自动着色,由Richard Zhang等人在论文Colorful Image Colorization中提出,并发表在2016年的ECCV会议中。该模型由8个conv层组成,每个conv层由2个或3个重复的卷积层和ReLU层组成,后面跟着一个BatchNorm层。网络中不包含池化层。
网络特点
1、设计了一个合适的损失函数来处理着色问题中的多模不确定性,维持了颜色的多样性。
2、将图像着色任务转化为一个自监督表达学习的任务。
3、在一些基准模型上获得了最好的效果。
03
数据处理
开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore Vision套件。
1.数据准备
本案例使用ImageNet数据集作为训练集和测试集。请在官网(https://www.image-net.org/)下载。训练集中包含1000个类别,总计大约120万张图片,测试集中包含5万图片。
解压后的数据集目录结构如下:
.dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
└── val/
2.训练集可视化
import os
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import mindspore
from src.process_datasets.data_generator import ColorizationDataset
#加载参数
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--shuffle', type=bool, default=True)
args = parser.parse_args(args=[])
plt.figure()
#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers)
data = dataset.run()
show_data = next(data.create_tuple_iterator())
show_images_original, _ = show_data
show_images_original = show_images_original.asnumpy()
#循环处理
for i in range(1, 5):
plt.subplot(1, 4, i)
temp = show_images_original[i-1]
temp = np.clip(temp, 0, 1)
plt.imshow(temp)
plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0)

3.构建网络
处理完数据后进行网络的搭建,Colorization的网络结构较为简单,采用CNN的网络结构。具体结构如下图所示。

网络的详细配置为:

其中X输出的空间分辨率,C输出的通道数;S计算步幅,大于1表示卷积后下采样,小于1表示卷积前上采样;D内核扩张;Sa在所有前一层的累积步数(积于前一层的所有步数);相对于输入的层的有效膨胀(层膨胀乘以累积步幅);BN层后是否使用BatchNorm层;L表示是否施加了1x1的卷积和交叉熵损失层。
4.损失函数

分类再平衡。

分类概率到点估计。

class NetLoss(nn.Cell):
"""连接网络和损失"""
def __init__(self, net):
super(NetLoss, self).__init__(auto_prefix=True)
self.net = net
self.loss = nn.CrossEntropyLoss(reduction='none')
def construct(self, images, targets, boost, mask):
""" build network """
outputs = self.net(images)
boost_nongray = boost * mask
squeeze = mindspore.ops.Squeeze(1)
boost_nongray = squeeze(boost_nongray)
result = self.loss(outputs, targets)
result_loss = (result * boost_nongray).mean()
return result_loss
04
模型实现
昇思MindSpore要求将损失函数、优化器等操作也看做nn.Cell的子类,所以我们可以自定义Color类,将网络和loss连接起来。
class ColorModel(nn.Cell):
"""定义Colorization网络"""
def __init__(self, my_train_one_step_cell_for_net):
super(ColorModel, self).__init__(auto_prefix=True)
self.my_train_one_step_cell_for_net = my_train_one_step_cell_for_net
def construct(self, result, targets, boost, mask):
loss = self.my_train_one_step_cell_for_net(result, targets, boost,
mask)
return loss
1.算法流程

**2.**模型训练
实例化损失函数,优化器,使用Model接口编译网络,开始训练。
import argparse
import os
from tqdm import tqdm
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import ops
import numpy as np
import matplotlib.pyplot as plt
from src.utils.utils import PriorBoostLayer, NNEncLayer, NonGrayMaskLayer, decode
from src.model.model import ColorizationModel
from src.model.colormodel import ColorModel
from src.process_datasets.data_generator import ColorizationDataset
from src.losses.loss import NetLoss
import warnings
warnings.filterwarnings('ignore')
#加载参数
parser = argparse.ArgumentParser()
parser.add_argument('--device_target',
default='GPU',
choices=['CPU', 'GPU', 'Ascend'],
type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--image_dir',
type=str,
default='./dataset/train',
help='path to dataset')
parser.add_argument('--checkpoint_dir',
type=str,
default='./checkpoints',
help='path for saving trained model')
parser.add_argument('--test_dirs',
type=str,
default='./images',
help='path for saving trained model')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.5e-4)
parser.add_argument('--save_step',
type=int,
default=200,
help='step size for saving trained models')
args = parser.parse_args(args=[])
if context.get_context('device_id') != args.device_id:
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
encode_layer = NNEncLayer(args)
boost_layer = PriorBoostLayer(args)
non_gray_mask = NonGrayMaskLayer()
#网络实例化
net = ColorizationModel()
#设置优化器
net_args = nn.Adam(net.trainable_params(), learning_rate=args.learning_rate)
#实例化NetLoss
net_with_criterion = NetLoss(net)
#实例化TrainOneStepWithLossScaleCell
scale_sense = nn.FixedLossScaleUpdateCell(1)
myTrainOneStepCellForNet = nn.TrainOneStepWithLossScaleCell(
net_with_criterion, net_args, scale_sense=scale_sense)
colormodel = ColorModel(myTrainOneStepCellForNet)
colormodel.set_train()
#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle,
args.num_parallel_workers)
data = dataset.run().create_tuple_iterator()
for epoch in range(args.num_epochs):
iters = 0
#为每轮训练读入数据
for images, img_ab in tqdm(data):
images = ops.expand_dims(images, 1)
encode, max_encode = encode_layer.forward(img_ab)
targets = mindspore.Tensor(max_encode, dtype=mindspore.int32)
boost = mindspore.Tensor(boost_layer.forward(encode),
dtype=mindspore.float32)
mask = mindspore.Tensor(non_gray_mask.forward(img_ab),
dtype=mindspore.float32)
net_loss = colormodel(images, targets, boost, mask)
#输出训练数据
print('[%d/%d]\tLoss_net:: %.4f' % (epoch + 1, args.num_epochs, net_loss[0]))
#中间保存训练结果
if iters % args.save_step == 0:
if not os.path.exists(args.checkpoint_dir):
os.makedirs(args.checkpoint_dir)
mindspore.save_checkpoint(
net,
os.path.join(args.checkpoint_dir, 'net' + str(epoch + 1) + '_' +
str(iters) + '.ckpt'))
img_ab_313 = net(images)
out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
color_img = decode(images, img_ab_313, args.resource)
if not os.path.exists(args.test_dirs):
os.makedirs(args.test_dirs)
plt.imsave(
args.test_dirs + '/' + str(epoch + 1) + '_' + str(iters) +
'%s_infer.png', color_img)
iters = iters + 1
3.模型推理
运行下面代码,将一张灰度图像输入到网络中,即可生成具有合理色彩的图像。
import argparse
import os
import matplotlib.pyplot as plt
import mindspore
import numpy as np
from mindspore import (context, load_checkpoint, load_param_into_net, ops)
from mindspore.train.model import Model
from tqdm import tqdm
from src.model.model import ColorizationModel
from src.process_datasets.data_generator import ColorizationDataset
from src.utils.utils import decode
parser = argparse.ArgumentParser()
parser.add_argument('--img_path', type=str, default='./dataset/val')
parser.add_argument('--ckpt_path', type=str, default='./checkpoints/net44_1600.ckpt')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--device_target', default='GPU', choices=['CPU', 'GPU', 'Ascend'], type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--infer_dirs', default='./dataset/output', type=str)
args = parser.parse_args(args=[])
mindspore.context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
#实例化网络
net = ColorizationModel()
#加载参数
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net, param_dict)
colorizer = Model(net)
dataset = ColorizationDataset(args.img_path, 1, prob=0)
data = dataset.run().create_tuple_iterator()
iters = 0
if not os.path.exists(args.infer_dirs):
os.makedirs(args.infer_dirs)
#循环处理图像
for images, img_ab in tqdm(data):
images = ops.expand_dims(images, 1)
img_ab_313 = colorizer.predict(images)
out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
color_img = decode(images, img_ab_313, args.resource)
plt.imsave(args.infer_dirs+'/'+str(iters)+'_infer.png', color_img)
iters = iters + 1


往期回顾