Implementing an Image Classification Application Using a Bayesian Neural Network

Deep learning models have strong fitting ability, while Bayesian theory offers good interpretability. MindSpore deep probabilistic programming (MindSpore Probability) combines deep learning with Bayesian learning. By treating network weights as distributions and introducing latent-space distributions, it can sample from these distributions during forward propagation, which introduces uncertainty and thereby improves the robustness and interpretability of the model.
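
For intuition: a Bayesian layer typically keeps a posterior mean and a standard deviation for every weight and samples a concrete weight on each forward pass via the reparameterization trick \(w = \mu + \sigma \cdot \epsilon\), \(\epsilon \sim N(0, 1)\). Below is a minimal NumPy sketch of this idea; it is illustrative only, not MindSpore's implementation (the softplus transform for the standard deviation is a common choice, assumed here).

[ ]:
# Illustrative sketch of reparameterized weight sampling (not MindSpore code).
import numpy as np

mu = np.zeros((3, 3))           # posterior mean of the weights
rho = np.full((3, 3), -3.0)     # unconstrained parameter for the std
sigma = np.log1p(np.exp(rho))   # softplus keeps the std positive
eps = np.random.randn(3, 3)     # standard normal noise
w = mu + sigma * eps            # the weight sample used for this forward pass
print(w)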

This chapter describes in detail the application of the Bayesian neural network from deep probabilistic programming in MindSpore. Before starting, make sure that you have correctly installed MindSpore 0.7.0-beta or a later version.

This example targets the GPU or Ascend 910 AI processor platform. You can download the complete sample code here: https://gitee.com/mindspore/mindspore/tree/r1.7/tests/st/probability/bnn_layers

Bayesian neural networks currently support only graph mode, so the code must set context.set_context(mode=context.GRAPH_MODE).

Using a Bayesian Neural Network

A Bayesian neural network is a basic model composed of a probabilistic model and a neural network; its weights are no longer fixed values but distributions. This example shows how to use the bnn_layers module of MDP to build a Bayesian neural network and use it to implement a simple image classification task. The overall workflow is as follows:

  1. Process the MNIST dataset;

  2. Define the Bayesian LeNet network;

  3. Define the loss function and optimizer;

  4. Load the dataset and train it.

Environment Preparation

Set the training mode to graph mode and the computing platform to GPU.

[1]:
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")

Data Preparation

Downloading the Dataset

The following sample code downloads the MNIST dataset files and saves them to the specified location.

[ ]:
import os
import requests

requests.packages.urllib3.disable_warnings()

def download_dataset(dataset_url, path):
    filename = dataset_url.split("/")[-1]
    save_path = os.path.join(path, filename)
    if os.path.exists(save_path):
        return
    if not os.path.exists(path):
        os.makedirs(path)
    res = requests.get(dataset_url, stream=True, verify=False)
    with open(save_path, "wb") as f:
        for chunk in res.iter_content(chunk_size=512):
            if chunk:
                f.write(chunk)
    print("The {} file is downloaded and saved in the path {} after processing".format(os.path.basename(dataset_url), path))

train_path = "datasets/MNIST_Data/train"
test_path = "datasets/MNIST_Data/test"

download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte", train_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte", train_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte", test_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte", test_path)

The directory structure of the downloaded dataset files is as follows:

./datasets/MNIST_Data
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte

Defining the Dataset Augmentation Method

The original MNIST training set consists of 60,000 single-channel digit images of \(28\times28\) pixels. The LeNet5 network with Bayesian layers used in this training expects input tensors of shape (32, 1, 32, 32), so the custom create_dataset function augments the original dataset to fit the training requirements. For a detailed explanation of the augmentation operations, see the Quick Start for Beginners tutorial.

[3]:
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dataset as ds
from mindspore import dtype as mstype  # needed for the TypeCast operation below

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define some parameters needed for data augmentation and normalization
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding augmentation operations
    c_trans = [
        CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),
        CV.Rescale(rescale_nml, shift_nml),
        CV.Rescale(rescale, shift),
        CV.HWC2CHW()
    ]
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers)

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
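
As a quick sanity check (a minimal sketch, assuming the dataset was downloaded to ./datasets/MNIST_Data/train as above), you can inspect one batch to confirm the tensor shapes match what the network expects:

[ ]:
# Inspect one batch from the augmented dataset.
ds_check = create_dataset("./datasets/MNIST_Data/train", batch_size=32)
for batch in ds_check.create_dict_iterator():
    print(batch["image"].shape)  # expected: (32, 1, 32, 32)
    print(batch["label"].shape)  # expected: (32,)
    break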

Defining the Bayesian Neural Network

In the classical LeNet5 network, data goes through the following computation: conv1 -> ReLU -> pooling -> conv2 -> ReLU -> pooling -> flatten -> fc1 -> fc2 -> fc3.

This example introduces probabilistic programming by using the bnn_layers module to turn the convolutional layers and fully connected layers into Bayesian layers.

[4]:
import mindspore.nn as nn
from mindspore.nn.probability import bnn_layers
import mindspore.ops as ops
from mindspore import dtype as mstype


class BNNLeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(BNNLeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = bnn_layers.ConvReparam(1, 6, 5, stride=1, padding=0, has_bias=False, pad_mode="valid")
        self.conv2 = bnn_layers.ConvReparam(6, 16, 5, stride=1, padding=0, has_bias=False, pad_mode="valid")
        self.fc1 = bnn_layers.DenseReparam(16 * 5 * 5, 120)
        self.fc2 = bnn_layers.DenseReparam(120, 84)
        self.fc3 = bnn_layers.DenseReparam(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

network = BNNLeNet5(num_class=10)
for layer in network.trainable_params():
    print(layer.name)
conv1.weight_posterior.mean
conv1.weight_posterior.untransformed_std
conv2.weight_posterior.mean
conv2.weight_posterior.untransformed_std
fc1.weight_posterior.mean
fc1.weight_posterior.untransformed_std
fc1.bias_posterior.mean
fc1.bias_posterior.untransformed_std
fc2.weight_posterior.mean
fc2.weight_posterior.untransformed_std
fc2.bias_posterior.mean
fc2.bias_posterior.untransformed_std
fc3.weight_posterior.mean
fc3.weight_posterior.untransformed_std
fc3.bias_posterior.mean
fc3.bias_posterior.untransformed_std

The printed information shows that in the LeNet network built with the bnn_layers module, both the convolutional layers and the fully connected layers are Bayesian layers.
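
Each Bayesian layer thus carries two trainable tensors per weight (and per bias, where present): the posterior mean and an untransformed standard deviation. A minimal sketch to inspect their shapes, reusing the network defined above:

[ ]:
# Print the shape of every trainable posterior parameter.
for param in network.trainable_params():
    print(param.name, tuple(param.shape))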

Defining the Loss Function and Optimizer

Next, define the loss function and the optimizer. The loss function, also called the objective function, is the training target of deep learning. It can be understood as the distance between the network output (logits) and the labels, and it is a scalar.

Common loss functions include mean squared error, L2 loss, hinge loss, and cross entropy. Image classification applications usually use the cross-entropy loss (CrossEntropy).

The optimizer is used to solve (train) the neural network. Because the parameter scale of a neural network is huge, it cannot be solved directly, so deep learning uses stochastic gradient descent (SGD) and its improved variants. MindSpore encapsulates common optimizers such as SGD, Adam, and Momentum. This example uses the Adam optimizer, which usually requires two parameters: the learning rate (learning_rate) and the weight decay (weight_decay).

A code sample for defining the loss function and the optimizer in MindSpore is as follows:

[5]:
import mindspore.nn as nn

# loss function definition
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

# optimization definition
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)

Training the Network

The training process of a Bayesian neural network is basically the same as that of a DNN. The only difference is that WithLossCell is replaced with WithBNNLossCell, which is suited to BNNs. In addition to the backbone and loss_fn parameters, WithBNNLossCell adds two parameters, dnn_factor and bnn_factor. They balance the overall network loss against the KL divergence of the Bayesian layers, preventing an overly large KL divergence from masking the overall loss (see the sketch after this list).

  • dnn_factor is the coefficient of the overall network loss computed by the loss function.

  • bnn_factor is the coefficient of the KL divergence of each Bayesian layer.
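
Conceptually, the loss returned by WithBNNLossCell combines the two terms roughly as follows. This is a schematic sketch of the weighting, not the exact MindSpore implementation:

[ ]:
# Schematic only: how dnn_factor and bnn_factor weight the two loss terms.
def combined_loss(task_loss, kl_terms, dnn_factor, bnn_factor):
    # task_loss: scalar loss from loss_fn; kl_terms: KL divergence of each Bayesian layer
    return dnn_factor * task_loss + bnn_factor * sum(kl_terms)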

Build the model training function train_model and the model validation function validate_model.

[6]:
import numpy as np

from mindspore import Tensor

def train_model(train_net, net, dataset):
    accs = []
    loss_sum = 0
    for _, data in enumerate(dataset.create_dict_iterator()):
        train_x = Tensor(data['image'].asnumpy().astype(np.float32))
        label = Tensor(data['label'].asnumpy().astype(np.int32))
        loss = train_net(train_x, label)
        output = net(train_x)
        log_output = ops.LogSoftmax(axis=1)(output)
        acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
        accs.append(acc)
        loss_sum += loss.asnumpy()

    loss_sum = loss_sum / len(accs)
    acc_mean = np.mean(accs)
    return loss_sum, acc_mean


def validate_model(net, dataset):
    accs = []
    for _, data in enumerate(dataset.create_dict_iterator()):
        train_x = Tensor(data['image'].asnumpy().astype(np.float32))
        label = Tensor(data['label'].asnumpy().astype(np.int32))
        output = net(train_x)
        log_output = ops.LogSoftmax(axis=1)(output)
        acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
        accs.append(acc)

    acc_mean = np.mean(accs)
    return acc_mean

Run the training.

[7]:
from mindspore.nn import TrainOneStepCell
from mindspore import Tensor
import numpy as np

net_with_loss = bnn_layers.WithBNNLossCell(network, criterion, dnn_factor=60000, bnn_factor=0.000001)
train_bnn_network = TrainOneStepCell(net_with_loss, optimizer)
train_bnn_network.set_train()

train_set = create_dataset('./datasets/MNIST_Data/train', 64, 1)
test_set = create_dataset('./datasets/MNIST_Data/test', 64, 1)

epoch = 10

for i in range(epoch):
    train_loss, train_acc = train_model(train_bnn_network, network, train_set)

    valid_acc = validate_model(network, test_set)

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tvalidation Accuracy: {:.4f}'.
          format(i+1, train_loss, train_acc, valid_acc))
Epoch: 1        Training Loss: 21444.8605       Training Accuracy: 0.8928       validation Accuracy: 0.9513
Epoch: 2        Training Loss: 9396.3887        Training Accuracy: 0.9536       validation Accuracy: 0.9635
Epoch: 3        Training Loss: 7320.2412        Training Accuracy: 0.9641       validation Accuracy: 0.9674
Epoch: 4        Training Loss: 6221.6970        Training Accuracy: 0.9685       validation Accuracy: 0.9731
Epoch: 5        Training Loss: 5450.9543        Training Accuracy: 0.9725       validation Accuracy: 0.9733
Epoch: 6        Training Loss: 4898.9741        Training Accuracy: 0.9754       validation Accuracy: 0.9767
Epoch: 7        Training Loss: 4505.7502        Training Accuracy: 0.9775       validation Accuracy: 0.9784
Epoch: 8        Training Loss: 4099.8783        Training Accuracy: 0.9797       validation Accuracy: 0.9791
Epoch: 9        Training Loss: 3795.2288        Training Accuracy: 0.9810       validation Accuracy: 0.9796
Epoch: 10       Training Loss: 3581.4254        Training Accuracy: 0.9823       validation Accuracy: 0.9773
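
Because the weights of the Bayesian layers are distributions, each forward pass samples a new set of weights, so repeated forward passes on the same input produce different outputs. The following minimal sketch (assuming the trained network and the test_set defined above) uses this to estimate predictive uncertainty; the sample count of 10 is an arbitrary illustrative choice:

[ ]:
# Estimate predictive mean/std by sampling several forward passes.
data = next(test_set.create_dict_iterator())
test_x = Tensor(data['image'].asnumpy().astype(np.float32))

n_samples = 10
outputs = np.stack([network(test_x).asnumpy() for _ in range(n_samples)])
pred_mean = outputs.mean(axis=0)      # average logits over weight samples
pred_std = outputs.std(axis=0)        # spread indicates model uncertainty
print(pred_mean.argmax(axis=1)[:10])  # predicted classes for first 10 images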