# Running Mode

[![View Source On Gitee](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.3/docs/mindspore/programming_guide/source_en/run.md)

## Overview

There are three execution modes: executing a single operator, executing a common function, and executing a network training model.

> Note: This document is applicable to GPU and Ascend environments.

## Executing a Single Operator

Execute a single operator and output the result. A code example is as follows:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

conv = nn.Conv2d(3, 4, 3, bias_init='zeros')
input_data = Tensor(np.ones([1, 3, 5, 5]).astype(np.float32))
output = conv(input_data)
print(output.asnumpy())
```

The output is as follows:

```text
[[[[ 0.06022915  0.06149777  0.06149777  0.06149777  0.01145121]
   [ 0.06402162  0.05889071  0.05889071  0.05889071 -0.00933781]
   [ 0.06402162  0.05889071  0.05889071  0.05889071 -0.00933781]
   [ 0.06402162  0.05889071  0.05889071  0.05889071 -0.00933781]
   [ 0.02712326  0.02096302  0.02096302  0.02096302 -0.01119636]]

  [[-0.0258286  -0.03362969 -0.03362969 -0.03362969 -0.00799183]
   [-0.0513729  -0.06778982 -0.06778982 -0.06778982 -0.03168458]
   [-0.0513729  -0.06778982 -0.06778982 -0.06778982 -0.03168458]
   [-0.0513729  -0.06778982 -0.06778982 -0.06778982 -0.03168458]
   [-0.04186669 -0.07266843 -0.07266843 -0.07266843 -0.04836193]]

  [[-0.00840744 -0.03043237 -0.03043237 -0.03043237  0.00172079]
   [ 0.00401019 -0.03755453 -0.03755453 -0.03755453 -0.00851137]
   [ 0.00401019 -0.03755453 -0.03755453 -0.03755453 -0.00851137]
   [ 0.00401019 -0.03755453 -0.03755453 -0.03755453 -0.00851137]
   [ 0.00270888 -0.03718876 -0.03718876 -0.03718876 -0.03043662]]

  [[-0.00982172  0.02009856  0.02009856  0.02009856  0.03327979]
   [ 0.02529106  0.04035065  0.04035065  0.04035065  0.01782833]
   [ 0.02529106  0.04035065  0.04035065  0.04035065  0.01782833]
   [ 0.02529106  0.04035065  0.04035065  0.04035065  0.01782833]
   [ 0.01015155  0.00781826  0.00781826  0.00781826 -0.02884173]]]]
```

> Note: Due to random factors in weight initialization, the actual output may differ; the preceding result is for reference only.

## Executing a Common Function

Combine multiple operators into a function, execute these operators by calling the function, and output the result. A code example is as follows:

```python
import numpy as np
from mindspore import context, Tensor
import mindspore.ops as ops

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

def add_func(x, y):
    z = ops.add(x, y)
    z = ops.add(z, x)
    return z

x = Tensor(np.ones([3, 3], dtype=np.float32))
y = Tensor(np.ones([3, 3], dtype=np.float32))
output = add_func(x, y)
print(output.asnumpy())
```

The output is as follows:

```text
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]
```

## Executing a Network Model

The [Model API](https://www.mindspore.cn/docs/api/en/r1.3/api_python/mindspore.html#mindspore.Model) of MindSpore is an advanced API used for training and validation. Layers with training or inference functions can be combined into this object, and training, inference, and prediction are implemented by calling the `train`, `eval`, and `predict` APIs, respectively.

> MindSpore does not support using multiple threads for training, inference, and prediction.

When initializing a `Model`, you can pass in the network, loss function, and optimizer as required. You can also configure `amp_level` to enable mixed precision and `metrics` to enable model evaluation, as in the sketch below.
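The following is a minimal sketch of such a configuration (not part of the original example); the toy `nn.Dense` network, the optimizer settings, and the `"O0"` level are placeholders you would replace with your own:

```python
import mindspore.nn as nn
from mindspore import Model
from mindspore.nn import Accuracy

# A toy single-layer network; any nn.Cell works here.
network = nn.Dense(10, 3)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)

# amp_level selects the mixed-precision policy ("O0" keeps float32 throughout);
# the metrics dict is consulted when model.eval is called.
model = Model(network, net_loss, net_opt,
              metrics={"Accuracy": Accuracy()}, amp_level="O0")
```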
> Executing a network model generates a `kernel_meta` directory under the execution directory and saves the operator cache files produced by network compilation, including `.o`, `.info`, and `.json` files. If you execute the same network model again, or modify only part of it, MindSpore automatically reuses the operator cache files in the `kernel_meta` directory, which significantly reduces network compilation time and improves execution performance. For details, see [Incremental Operator Build](https://www.mindspore.cn/docs/programming_guide/en/r1.3/incremental_operator_build.html).

Before executing the network, download the required dataset files to the specified directory (the following commands can be run in a Jupyter notebook):

```bash
mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte --no-check-certificate
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte --no-check-certificate
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte --no-check-certificate
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte --no-check-certificate
tree ./datasets/MNIST_Data
```

```text
./datasets/MNIST_Data
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte

2 directories, 4 files
```
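As a quick sanity check (not part of the original tutorial), you can load the downloaded files with `MnistDataset` and confirm the sample count:

```python
import mindspore.dataset as ds

# Load the raw training split downloaded above and count its samples.
mnist_ds = ds.MnistDataset("./datasets/MNIST_Data/train")
print(mnist_ds.get_dataset_size())  # 60000 for the MNIST training split
```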
### Executing a Training Model

Call the `train` API of `Model` to implement training. A code example is as follows:

```python
import os
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.nn as nn
from mindspore import context, Model
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """create dataset for train or test"""
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """
    LeNet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    model_path = "./models/ckpt/mindspore_run/"
    os.system("rm -rf {0}*.ckpt {0}*.meta {0}*.pb".format(model_path))
    ds_train_path = "./datasets/MNIST_Data/train/"
    ds_train = create_dataset(ds_train_path, 32)
    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=model_path, config=config_ck)
    model = Model(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train(1, ds_train, callbacks=[LossMonitor(375), ckpoint_cb], dataset_sink_mode=True)
```

```text
============== Starting Training ==============
epoch: 1 step: 375, loss is 2.2898183
epoch: 1 step: 750, loss is 2.2777305
epoch: 1 step: 1125, loss is 0.27802905
epoch: 1 step: 1500, loss is 0.032973606
epoch: 1 step: 1875, loss is 0.06105463
```

> For details about how to obtain the MNIST dataset used in the example, see [Downloading the Dataset](https://www.mindspore.cn/docs/programming_guide/en/r1.3/quick_start/quick_start.html#downloading-the-dataset).

> Use the PyNative mode for debugging, including the execution of a single operator, a common function, and a network training model. For details, see [Debugging in PyNative Mode](https://www.mindspore.cn/docs/programming_guide/en/r1.3/debug_in_pynative_mode.html).

> To freely control loop iterations, traverse datasets, and so on, see the "Customizing a Training Cycle" section of the programming guide [Training](https://www.mindspore.cn/docs/programming_guide/en/r1.3/train.html#customizing-a-training-cycle); a minimal sketch follows below.
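For reference, such a customized training cycle might look like the following sketch (not part of the original example). It reuses `network`, `net_loss`, `net_opt`, and `ds_train` from the code above; wrapping with `nn.WithLossCell` and `nn.TrainOneStepCell` is one common way to assemble the loop, per the linked guide:

```python
import mindspore.nn as nn

# Wrap the network with its loss, then with a one-step training cell.
net_with_loss = nn.WithLossCell(network, net_loss)
train_net = nn.TrainOneStepCell(net_with_loss, net_opt)
train_net.set_train()

# Drive the epochs and steps manually instead of calling model.train.
for epoch in range(1):
    for data, label in ds_train.create_tuple_iterator():
        loss = train_net(data, label)
    print("epoch: {}, loss is {}".format(epoch + 1, loss))
```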
### Executing an Inference Model

Call the `eval` API of `Model` to implement inference. To facilitate model evaluation, you can set `metrics` when the `Model` API is initialized.

Metrics are used to evaluate models. Common metrics include Accuracy, Fbeta, Precision, Recall, and TopKCategoricalAccuracy. A single metric usually cannot evaluate model quality comprehensively, so multiple metrics are often used together.

Common built-in evaluation metrics are as follows:

- `Accuracy`: evaluates a classification model. Generally, accuracy refers to the proportion of results that the model predicts correctly among all results. Formula: $$Accuracy = (TP + TN)/(TP + TN + FP + FN)$$

- `Precision`: proportion of correctly predicted positive results among all predicted positive results. Formula: $$Precision = TP/(TP + FP)$$

- `Recall`: proportion of correctly predicted positive results among all actual positive results. Formula: $$Recall = TP/(TP + FN)$$

- `Fbeta`: weighted harmonic mean of precision and recall. Formula: $$F_\beta = (1 + \beta^2) \cdot \frac{precision \cdot recall}{(\beta^2 \cdot precision) + recall}$$

- `TopKCategoricalAccuracy`: calculates the top-K categorical accuracy.
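A metric can also be used on its own through its `clear`, `update`, and `eval` methods, the same interface that `Model.eval` drives internally. A minimal sketch (not part of the original example) with two hand-made samples:

```python
import numpy as np
from mindspore import Tensor
from mindspore.nn import Accuracy

# Two samples, three classes; rows are prediction scores per class.
y_pred = Tensor(np.array([[0.2, 0.5, 0.3],
                          [0.9, 0.05, 0.05]], dtype=np.float32))
y_true = Tensor(np.array([1, 0], dtype=np.int32))

metric = Accuracy('classification')
metric.clear()                 # reset internal counters
metric.update(y_pred, y_true)  # accumulate one batch
print(metric.eval())           # 1.0: both samples are predicted correctly
```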
A code example of calling the `eval` API is as follows:

```python
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Accuracy, Precision


class LeNet5(nn.Cell):
    """
    LeNet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """create dataset for train or test"""
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    model_path = "./models/ckpt/mindspore_run/"
    ds_eval_path = "./datasets/MNIST_Data/test/"
    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    repeat_size = 1
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), "Precision": Precision()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(model_path + "checkpoint_lenet-1_1875.ckpt")
    load_param_into_net(network, param_dict)
    ds_eval = create_dataset(ds_eval_path, 32, repeat_size)
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
```

```text
============== Starting Testing ==============
============== {'Accuracy': 0.960136217948718, 'Precision': array([0.95763547, 0.98059965, 0.99153439, 0.93333333, 0.97322348,
       0.99385749, 0.98502674, 0.93179724, 0.8974359 , 0.97148676])} ==============
```

In the preceding information:

- `load_checkpoint`: loads the checkpoint model parameter file and returns a parameter dictionary.
- `checkpoint_lenet-1_1875.ckpt`: name of the saved checkpoint model file.
- `load_param_into_net`: loads parameters into the network.

> For details about how to save the `checkpoint_lenet-1_1875.ckpt` file, see [Training the Network](https://www.mindspore.cn/docs/programming_guide/en/r1.3/quick_start/quick_start.html#training-the-network).
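The `Model` API also provides a `predict` method for direct inference on an in-memory batch. A minimal sketch (not part of the original example), reusing the `model` object loaded above; the random input stands in for a preprocessed MNIST image batch:

```python
import numpy as np
from mindspore import Tensor

# One MNIST-shaped batch: batch=1, channel=1, 32x32 pixels after preprocessing.
input_data = Tensor(np.random.rand(1, 1, 32, 32).astype(np.float32))
output = model.predict(input_data)
print(output.shape)  # (1, 10): one logit per class
```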