[MindSpore Made Easy] Application Practice Series: ResNet50 Image Classification in Depth (Part 2)
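With the network analyzed, the next step is the data.
The official tutorial uses the CIFAR-10 dataset. Since this is a hands-on application and 10 classes is a bit few, I chose CIFAR-100 instead.
CIFAR-100 and CIFAR-10 have essentially the same data structure.
Compared with CIFAR-10, CIFAR-100 adds a second, coarser label: its 100 fine classes are grouped into 20 superclasses. In practice, however, the superclass labels are not used in training.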

CIFAR-100 python version
Download link:
http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz

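After extraction, the cifar-100-python folder contains the relevant files:
meta: maps each label index to its class name.
train: the training set.
test: the test set.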
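If you prefer to fetch the archive without mindvision, a minimal sketch using only the Python standard library follows. The function name download_and_extract is my own choice, not part of any library API; the ./dataset root matches the path used in the code below, and the check for meta, train and test simply confirms the three files listed above were extracted.

```python
import os
import tarfile
import urllib.request

# URL taken from the text above
CIFAR100_URL = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"


def download_and_extract(root, url=CIFAR100_URL, archive_path=None):
    """Download (if needed) and extract the CIFAR-100 python archive into root.

    Returns the path of the extracted cifar-100-python directory.
    """
    os.makedirs(root, exist_ok=True)
    if archive_path is None:
        archive_path = os.path.join(root, os.path.basename(url))
    # Only download when the archive is not already on disk
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(root)
    data_dir = os.path.join(root, "cifar-100-python")
    # The extracted folder is expected to contain these three files
    for name in ("meta", "train", "test"):
        path = os.path.join(data_dir, name)
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
    return data_dir
```

Calling download_and_extract("./dataset") once produces the same ./dataset/cifar-100-python layout that the download=False setting below relies on.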
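from mindvision.classification.dataset import Cifar100
# Dataset root directory (here it holds the extracted cifar-100-python folder)
data_dir = "./dataset"
# Training split: batches of 6 images resized to 32x32; download=False uses the local copy
dataset_train = Cifar100(path=data_dir, split='train', batch_size=6, resize=32, download=False)
ds_train = dataset_train.run()
# get_dataset_size() returns the number of batches, not the number of images
step_size = ds_train.get_dataset_size()
print(step_size)
# Test split, used here for validation
dataset_val = Cifar100(path=data_dir, split='test', batch_size=6, resize=32, download=False)
ds_val = dataset_val.run()
print(ds_val.get_dataset_size())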
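Because get_dataset_size() counts batches, the printed numbers can be checked by hand: CIFAR-100 has 50,000 training and 10,000 test images. I have not verified whether mindvision drops the last incomplete batch, so this sketch computes both cases; num_batches is a hypothetical helper, not a MindSpore function.

```python
import math


def num_batches(num_samples, batch_size, drop_remainder):
    """Number of batches one epoch yields for a given batch size."""
    if drop_remainder:
        # The incomplete last batch is discarded
        return num_samples // batch_size
    # The incomplete last batch is kept
    return math.ceil(num_samples / batch_size)


# CIFAR-100 split sizes: 50,000 training and 10,000 test images
for split, n in (("train", 50000), ("test", 10000)):
    print(split, num_batches(n, 6, True), num_batches(n, 6, False))
# train 8333 8334
# test 1666 1667
```

Whichever of the two values the code above prints tells you how the last partial batch is handled.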
import numpy as np
import matplotlib.pyplot as plt
# Take the first batch (6 images) from the training set
data = next(ds_train.create_dict_iterator())
images = data["image"].asnumpy()
labels = data["label"].asnumpy()
print(f"Image shape: {images.shape}, Label: {labels}")
# Per-channel mean/std used to undo the normalization for display
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2023, 0.1994, 0.2010])
plt.figure()
for i in range(1, 7):
    plt.subplot(2, 3, i)
    # CHW -> HWC, the layout matplotlib expects
    image_trans = np.transpose(images[i - 1], (1, 2, 0))
    # Reverse (x - mean) / std, then clip to the valid [0, 1] range
    image_trans = std * image_trans + mean
    image_trans = np.clip(image_trans, 0, 1)
    plt.title(f"{dataset_train.index2label[labels[i - 1]]}")
    plt.imshow(image_trans)
    plt.axis("off")
plt.show()
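The loop above inverts a per-channel normalization: if the pipeline computes x = (p - mean) / std, then p = std * x + mean. A minimal NumPy round-trip sketch, assuming (not verified here) that the mean/std values in the loop are the ones mindvision used and that pixels were first scaled to [0, 1]; normalize and denormalize are illustrative helper names.

```python
import numpy as np

# Per-channel statistics copied from the plotting loop above
MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.2023, 0.1994, 0.2010])


def normalize(img_hwc):
    """Map an HWC image with values in [0, 1] to standardized channels."""
    return (img_hwc - MEAN) / STD


def denormalize(img_chw):
    """Invert normalize() for a CHW tensor and return a displayable HWC image."""
    img_hwc = np.transpose(img_chw, (1, 2, 0))  # CHW -> HWC for matplotlib
    img_hwc = STD * img_hwc + MEAN              # undo (x - mean) / std
    return np.clip(img_hwc, 0.0, 1.0)           # imshow expects floats in [0, 1]


rng = np.random.default_rng(0)
original = rng.random((32, 32, 3))
restored = denormalize(np.transpose(normalize(original), (2, 0, 1)))
print(np.allclose(original, restored))  # True
```

If the displayed images look washed out or oddly tinted, the mean/std used for display probably differ from the ones the pipeline applied.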
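The code above is sample code that uses mindvision's Cifar100 class directly.
Setting download=True downloads the dataset automatically. I had downloaded it beforehand, so I set it to False.
mindvision wraps the data loading, but we can also look at how to read the files ourselves.
The CIFAR-10 data layout is described as follows: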
- data -- a 10000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.
- labels -- a list of 10000 numbers in the range 0-9. The number at index i indicates the label of the i-th image in the array data.
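This layout means the whole data array can be turned into images with a single reshape and transpose. A minimal NumPy sketch; rows_to_hwc is a hypothetical helper name, and the synthetic row only stands in for real data.

```python
import numpy as np


def rows_to_hwc(data):
    """Convert an (N, 3072) uint8 array in CIFAR row layout to (N, 32, 32, 3).

    Each row holds 1024 red, then 1024 green, then 1024 blue values, each
    channel stored row-major, so a row reshapes directly to (3, 32, 32).
    """
    data = np.asarray(data, dtype=np.uint8)
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


# Synthetic check: entry k of a row is pixel (r, c) of channel ch, where
# k = ch * 1024 + r * 32 + c
row = (np.arange(3072) % 256).astype(np.uint8)
img = rows_to_hwc(row[np.newaxis, :])[0]
print(img.shape)                               # (32, 32, 3)
print(np.array_equal(img[0, :, 0], row[:32]))  # True: first 32 entries = red, first image row
```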
The binary version of the CIFAR-100 is just like the binary version of the CIFAR-10, except that each image has two label bytes (coarse and fine) and 3072 pixel bytes, so the binary files look like this:
<1 x coarse label><1 x fine label><3072 x pixel>
...
<1 x coarse label><1 x fine label><3072 x pixel>
The 3072 values are the R, G and B channel values, exactly 3 × 32 × 32 = 3072.
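This article reads the python version, but the binary layout quoted above is also easy to parse: each record is 1 + 1 + 3072 = 3074 bytes. A minimal NumPy sketch that works on raw bytes; parse_cifar100_binary is a hypothetical name, and the records in the example are synthetic.

```python
import numpy as np

RECORD_BYTES = 1 + 1 + 3072  # coarse label, fine label, pixels


def parse_cifar100_binary(raw):
    """Split raw CIFAR-100 binary bytes into coarse labels, fine labels, images.

    Returns (coarse, fine, images) where images has shape (N, 32, 32, 3).
    """
    buf = np.frombuffer(raw, dtype=np.uint8)
    if buf.size % RECORD_BYTES != 0:
        raise ValueError("input length is not a multiple of 3074 bytes")
    records = buf.reshape(-1, RECORD_BYTES)
    coarse = records[:, 0]
    fine = records[:, 1]
    # Pixel bytes use the same channel-major, row-major layout as CIFAR-10
    images = records[:, 2:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return coarse, fine, images


# Two synthetic records: (coarse=3, fine=42) and (coarse=19, fine=99)
raw = bytes([3, 42]) + bytes(3072) + bytes([19, 99]) + bytes([255]) * 3072
coarse, fine, images = parse_cifar100_binary(raw)
print(coarse.tolist(), fine.tolist(), images.shape)  # [3, 19] [42, 99] (2, 32, 32, 3)
```

For a real binary file you would pass the result of reading it in "rb" mode to the function.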
import pickle
import numpy as np
import matplotlib.pyplot as plt
def unpickle(file):
    # The CIFAR python files were pickled under Python 2; encoding='bytes' keeps keys as bytes
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
a = unpickle(r'./dataset/cifar-100-python/train')
meta = unpickle(r'./dataset/cifar-100-python/meta')
print(a.keys())
print(meta.keys())
# First training image: a flat array of 3072 uint8 values
img = a[b'data'][0]
fine_label = a[b'fine_labels'][0]
coarse_label = a[b'coarse_labels'][0]
print(fine_label)
print(coarse_label)
# meta maps label indices to class names (stored as bytes)
print(meta[b'fine_label_names'][fine_label])
print(meta[b'coarse_label_names'][coarse_label])
# 3072 -> (3, 32, 32) channel-first, then HWC for imshow
imgrgb = img.reshape(3, 32, 32)
imgrgb = np.transpose(imgrgb, (1, 2, 0))
plt.imshow(imgrgb)
plt.show()
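The superclass labels are not used for training, but they are easy to explore: each fine class belongs to exactly one of the 20 superclasses. A minimal sketch that recovers this mapping from the two label lists; fine_to_coarse and group_by_coarse are hypothetical helpers that take plain lists, so they work on a[b'fine_labels'] and a[b'coarse_labels'] from the code above.

```python
from collections import defaultdict


def fine_to_coarse(fine_labels, coarse_labels):
    """Build a {fine label: coarse label} dict from paired label lists.

    Raises ValueError if a fine label appears with two different coarse
    labels, which would contradict a strict class hierarchy.
    """
    mapping = {}
    for f, c in zip(fine_labels, coarse_labels):
        if mapping.setdefault(f, c) != c:
            raise ValueError(f"fine label {f} maps to {mapping[f]} and {c}")
    return mapping


def group_by_coarse(mapping):
    """Invert the mapping: {coarse label: sorted list of fine labels}."""
    groups = defaultdict(list)
    for f, c in mapping.items():
        groups[c].append(f)
    return {c: sorted(fs) for c, fs in sorted(groups.items())}
```

On the full training set, fine_to_coarse(a[b'fine_labels'], a[b'coarse_labels']) should give 100 entries and group_by_coarse 20 groups of 5 fine classes each; the class names in meta are bytes, so call .decode() on them for readable output.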


At 32 × 32, an enlarged image is almost unrecognizable even to a human; it looks somewhat better when displayed smaller.

That covers the dataset processing.