代码
mindspore安装及案例测试

mindspore安装及案例测试

mindspore安装及案例测试

1、MindSpore 安装

在 MindSpore 官网选择与本机环境对应的版本：电脑有 GPU 的可以安装 CUDA（GPU）版本；基础版直接通过 pip 安装 CPU 版本即可。

安装完成后，运行以下命令验证是否安装成功：

 python -c "import mindspore;mindspore.set_context(device_target='CPU');mindspore.run_check()"

2、MindSpore 案例测试（使用 LeNet 对 MNIST 手写数字进行分类）

from mindvision.dataset import Mnist

# Download (if not already present) and load the MNIST train/test splits.
# Both splits share the same location, batch size and resize target, so those
# settings are named once here instead of being repeated per call.
DATA_PATH = "./mnist"
BATCH_SIZE = 32
IMAGE_SIZE = 32  # resize target passed to Mnist for every image

# The training split is shuffled so each epoch sees batches in a new order.
download_train = Mnist(path=DATA_PATH, split="train", batch_size=BATCH_SIZE,
                       shuffle=True, resize=IMAGE_SIZE, download=True)
# Accuracy does not depend on sample order, so shuffling the evaluation split
# only adds work and makes runs non-deterministic; keep it in a fixed order.
download_eval = Mnist(path=DATA_PATH, split="test", batch_size=BATCH_SIZE,
                      shuffle=False, resize=IMAGE_SIZE, download=True)

# run() builds the actual dataset objects consumed by Model.train / Model.eval.
dataset_train = download_train.run()
dataset_eval = download_eval.run()

from mindvision.classification.models import lenet

# LeNet classifier with one output per MNIST digit class (0-9).
NUM_DIGIT_CLASSES = 10
network = lenet(num_classes=NUM_DIGIT_CLASSES)

import mindspore.nn as nn
from mindspore.train import Model  # used further down to wrap the network

# Loss: softmax cross-entropy computed on the raw network outputs.
# sparse=True -> labels are given as class indices, not one-hot vectors;
# reduction='mean' -> per-sample losses are averaged over the batch.
net_loss = nn.SoftmaxCrossEntropyWithLogits(
    sparse=True,
    reduction='mean',
)

# Optimizer: momentum SGD over every trainable parameter of the network.
net_opt = nn.Momentum(
    network.trainable_params(),
    learning_rate=0.01,
    momentum=0.9,
)

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Save one checkpoint per epoch. The interval is the number of batches in the
# training dataset (1875 for the MNIST train split at batch_size=32, the value
# that used to be hard-coded here). Deriving it keeps the interval correct if
# the batch size or dataset changes.
steps_per_epoch = dataset_train.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                             keep_checkpoint_max=10)

# Checkpoint files go to ./lenet and their names start with the "lenet" prefix.
ckpoint = ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)

from mindvision.engine.callback import LossMonitor

# Bundle network, loss and optimizer into a Model; the 'acc' metric is what
# Model.eval reports later.
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'acc'})

# Train for one epoch. Callbacks are passed as a list, the documented type
# for `callbacks` (a Callback or a list of Callbacks). The original set had no
# defined iteration order, so the callback execution order could vary between
# runs.
# LossMonitor's 0.01 matches learning_rate=0.01 given to the optimizer above.
model.train(1, dataset_train, callbacks=[ckpoint, LossMonitor(0.01)])

# Evaluate the trained model on the test split and print the metric results
# (the 'acc' metric configured on the Model).
acc = model.eval(dataset_eval)
print(f"{acc}")

from mindspore import load_checkpoint, load_param_into_net

# Reload the weights saved by the ModelCheckpoint callback. Ask the callback
# which file it actually wrote instead of assuming a fixed name: when
# checkpoints with the same prefix already exist in ./lenet (e.g. from an
# earlier run), MindSpore changes the prefix (such as "lenet_1-1_1875.ckpt").
# In that case a hard-coded path would silently load a stale checkpoint.
# Fall back to the historical path if the callback has not saved anything.
ckpt_path = ckpoint.latest_ckpt_file_name or "./lenet/lenet-1_1875.ckpt"
param_dict = load_checkpoint(ckpt_path)

# Copy the loaded parameters into the network in place.
load_param_into_net(network, param_dict)

import numpy as np
from mindspore import Tensor
import matplotlib.pyplot as plt

# Take a single small batch from the test split to display and classify.
NUM_SAMPLES = 6
GRID_COLS = 3
mnist = Mnist("./mnist", split="test", batch_size=NUM_SAMPLES, resize=32)
dataset_infer = mnist.run()
de_test = dataset_infer.create_dict_iterator()
data = next(de_test)
images = data["image"].asnumpy()
labels = data["label"].asnumpy()

# Lay the batch out in a grid whose row count follows from the batch size
# (ceiling division), so changing NUM_SAMPLES needs no plotting changes.
grid_rows = -(-len(images) // GRID_COLS)
plt.figure()
for idx, image in enumerate(images, start=1):
    plt.subplot(grid_rows, GRID_COLS, idx)
    # image[0]: first channel of the sample; assumes (C, H, W) layout per
    # sample — TODO confirm against the Mnist pipeline.
    plt.imshow(image[0], interpolation="none", cmap="gray")
plt.show()

# Predict a class for each image: argmax over the per-class outputs (axis 1).
output = model.predict(Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)
print(f'Predicted:"{predicted}",Actual:"{labels}"')