[MindSpore Yidiantong] Applied Practice Series: ResNet50 Image Classification Explained (Part 2)

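With the network analyzed, the next step is the data.

The official tutorial uses the CIFAR-10 dataset. Since this is an applied exercise and 10 classes is a bit few, we use CIFAR-100 instead.

CIFAR-100 and CIFAR-10 share essentially the same data structure.

Compared with CIFAR-10, CIFAR-100 adds a superclass level: each image carries a fine label (one of 100 classes) and a coarse label (one of 20 superclasses). The superclass label is not actually used in training here.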

CIFAR-100 python version

Download link:

http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
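
For readers who prefer to fetch the archive by hand rather than through download=True, the following is a minimal sketch using only the Python standard library. The helper name fetch_and_extract and the target directory are illustrative choices, not part of MindSpore or mindvision.

```python
import tarfile
import urllib.request
from pathlib import Path

# URL taken from the download link above.
CIFAR100_URL = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"


def fetch_and_extract(url: str, root: str) -> Path:
    """Download a .tar.gz archive into root (unless already cached) and unpack it there."""
    root_dir = Path(root)
    root_dir.mkdir(parents=True, exist_ok=True)

    # Name the local copy after the last component of the URL.
    archive = root_dir / url.rsplit("/", 1)[-1]
    if not archive.exists():
        # Only fetch the archive when no local copy exists yet.
        urllib.request.urlretrieve(url, str(archive))

    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(root_dir)
    return root_dir
```

Calling fetch_and_extract(CIFAR100_URL, "./dataset") should leave the files under ./dataset/cifar-100-python, which is the path the reading code later in this post assumes.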

meta maps each label index to its class name.

train is the training set.

test is the test set.

from mindvision.classification.dataset import Cifar100

# Dataset root directory
data_dir = "./dataset"

# Training split; download=False because the archive is already in data_dir
dataset_train = Cifar100(path=data_dir, split='train', batch_size=6, resize=32, download=False)
ds_train = dataset_train.run()
step_size = ds_train.get_dataset_size()  # number of batches per epoch
print(step_size)
# Test split, used for validation
dataset_val = Cifar100(path=data_dir, split='test', batch_size=6, resize=32, download=False)
ds_val = dataset_val.run()
print(ds_val.get_dataset_size())

import numpy as np
import matplotlib.pyplot as plt

data = next(ds_train.create_dict_iterator())

images = data["image"].asnumpy()
labels = data["label"].asnumpy()
print(f"Image shape: {images.shape}, Label: {labels}")

# Per-channel mean and std used to undo the normalization before display
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2023, 0.1994, 0.2010])

plt.figure()
for i in range(1, 7):
    plt.subplot(2, 3, i)
    # CHW -> HWC, the layout plt.imshow expects
    image_trans = np.transpose(images[i - 1], (1, 2, 0))
    image_trans = std * image_trans + mean
    image_trans = np.clip(image_trans, 0, 1)
    plt.title(f"{dataset_train.index2label[labels[i - 1]]}")
    plt.imshow(image_trans)
    plt.axis("off")
plt.show()

The sample code above uses the Cifar100 class from mindvision directly.

Setting download=True downloads the dataset automatically. I had already downloaded it, so I set it to False.
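
The plotting loop in the sample code multiplies by std and adds mean before calling imshow, which suggests the pipeline normalized each channel as (x - mean) / std and moved the channel axis first. That is an inference from the inverse step, not something checked against the mindvision source. Under that assumption, a small NumPy sketch of the round trip looks like this (the helper names are made up for illustration):

```python
import numpy as np

# Per-channel statistics copied from the plotting loop in the sample code.
MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.2023, 0.1994, 0.2010])


def normalize_chw(image_hwc: np.ndarray) -> np.ndarray:
    """Normalize an HWC float image in [0, 1] per channel and return it as CHW."""
    return np.transpose((image_hwc - MEAN) / STD, (2, 0, 1))


def denormalize_hwc(image_chw: np.ndarray) -> np.ndarray:
    """Invert normalize_chw so the result can be handed to plt.imshow."""
    image_hwc = np.transpose(image_chw, (1, 2, 0))
    return np.clip(STD * image_hwc + MEAN, 0, 1)


# Random stand-in for a real 32x32 RGB image with values in [0, 1).
rng = np.random.default_rng(0)
original = rng.random((32, 32, 3))
restored = denormalize_hwc(normalize_chw(original))
print(np.allclose(original, restored))
```

Without the denormalization step, imshow would receive values outside [0, 1] and the colours would look washed out or clipped.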

mindvision wraps the data loading, but we can also look at how to read the files ourselves.

The CIFAR-10 data is laid out as follows:

  • data -- a 10000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.
  • labels -- a list of 10000 numbers in the range 0-9. The number at index i indicates the label of the ith image in the array data.
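
The row layout described in the list above (red plane, then green, then blue, each in row-major order) can be checked with a small NumPy sketch. The row here is synthetic, and row_to_image is an illustrative helper name, not part of any dataset API.

```python
import numpy as np


def row_to_image(row: np.ndarray) -> np.ndarray:
    """Convert one 3072-value row (R plane, G plane, B plane) into a 32x32x3 HWC image."""
    planes = row.reshape(3, 32, 32)         # channel, image row, image column
    return np.transpose(planes, (1, 2, 0))  # height, width, channel


# Synthetic row: red plane all 10, green plane all 20, blue plane all 30.
row = np.concatenate([
    np.full(1024, 10, dtype=np.uint8),
    np.full(1024, 20, dtype=np.uint8),
    np.full(1024, 30, dtype=np.uint8),
])
# Flat index 32 is the red value of the first pixel in the second image row.
row[32] = 255

image = row_to_image(row)
print(image.shape)     # (32, 32, 3)
print(image[0, 0])     # [10 20 30]
print(image[1, 0, 0])  # 255
```

Reshaping straight to (32, 32, 3) would interleave the planes incorrectly, which is why the reshape to (3, 32, 32) comes first and the transpose second.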

The binary version of the CIFAR-100 is just like the binary version of the CIFAR-10, except that each image has two label bytes (coarse and fine) and 3072 pixel bytes, so the binary files look like this:

<1 x coarse label><1 x fine label><3072 x pixel>

...

<1 x coarse label><1 x fine label><3072 x pixel>

The 3072 values are the R, G, and B channel values of each pixel, which is exactly 3*32*32 (3 channels of 32x32 pixels).
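
The binary record layout quoted above, one coarse-label byte, one fine-label byte, then 3072 pixel bytes, can be parsed with NumPy as sketched below. The function name parse_cifar100_binary is hypothetical and the input is two synthetic records. Note that the binary release is a separate download from the python version used in this post.

```python
import numpy as np

RECORD_BYTES = 1 + 1 + 3072  # coarse label, fine label, pixel bytes


def parse_cifar100_binary(raw: bytes):
    """Split a buffer of CIFAR-100 binary records into labels and CHW images."""
    if len(raw) % RECORD_BYTES != 0:
        raise ValueError("buffer length is not a multiple of the record size")
    records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, RECORD_BYTES)
    coarse = records[:, 0]
    fine = records[:, 1]
    images = records[:, 2:].reshape(-1, 3, 32, 32)
    return coarse, fine, images


# Two synthetic records stand in for the contents of a real file.
rec0 = bytes([4, 57]) + bytes(3072)           # all-black image
rec1 = bytes([11, 3]) + bytes([255]) * 3072   # all-white image
coarse, fine, images = parse_cifar100_binary(rec0 + rec1)
print(coarse.tolist(), fine.tolist(), images.shape)
```

For a real file, the same function can be applied to the bytes read with open(path, "rb").read(), where path points at a file from the binary release.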

import pickle

import numpy as np
import matplotlib.pyplot as plt


def unpickle(file):
    """Load one CIFAR pickle file; the dictionary keys come back as bytes."""
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


train = unpickle(r'./dataset/cifar-100-python/train')
meta = unpickle(r'./dataset/cifar-100-python/meta')
print(train.keys())
print(meta.keys())

# First training image with its fine (class) and coarse (superclass) labels
img = train[b'data'][0]
fine_label = train[b'fine_labels'][0]
coarse_label = train[b'coarse_labels'][0]
print(fine_label)
print(coarse_label)
print(meta[b'fine_label_names'][fine_label])
print(meta[b'coarse_label_names'][coarse_label])

# 3072 values -> (channel, height, width) -> (height, width, channel) for imshow
imgrgb = img.reshape(3, 32, 32)
imgrgb = np.transpose(imgrgb, (1, 2, 0))

plt.imshow(imgrgb)
plt.show()

At 32*32, an image becomes unrecognizable to the human eye once it is enlarged; shrunk back down, it looks slightly better.

That covers the dataset processing.