mindspore.train.Model
- class mindspore.train.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level='O0', boost_level='O0', **kwargs)
- High-level API for training or inference.
- Model groups layers into an object with training and inference features based on the arguments.
- Note
- To use mixed precision, the optimizer parameter must also be set; otherwise mixed precision does not take effect. When mixed precision is used, global_step in the optimizer may differ from cur_step_num in Model.
- After custom_mixed_precision or auto_mixed_precision has been used for precision conversion, the conversion cannot be performed again. If Model is used to train an already converted network, amp_level needs to be set to O0 to avoid a duplicate precision conversion.
 - Parameters
- network (Cell) – A training or testing network. 
- loss_fn (Cell) – Objective function. If loss_fn is None, the network should contain the calculation of loss. Default: None.
- optimizer (Cell) – Optimizer for updating the weights. If optimizer is None, the network needs to do backpropagation and update weights. Default: None.
- metrics (Union[dict, set]) – A dictionary or a set of metrics for model evaluation, e.g. {'accuracy', 'recall'}. Default: None.
- eval_network (Cell) – Network for evaluation. If not defined, network and loss_fn are wrapped as eval_network. Default: None.
- eval_indexes (list) – Used when eval_network is defined. If eval_indexes is None (the default), all outputs of the eval_network are passed to metrics. If eval_indexes is set, it must contain three elements: the positions of the loss value, the predicted value and the label in the outputs of the eval_network. In this case, the loss value is passed to the Loss metric, and the predicted value and label are passed to the other metrics. mindspore.train.Metric.set_indexes() is recommended instead of eval_indexes. Default: None.
- amp_level (str) – Option for the level argument of mindspore.amp.build_train_network(), the level for mixed precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
- For details on amp_level, refer to mindspore.amp.auto_mixed_precision().
- The BatchNorm strategy can be changed by the keep_batchnorm_fp32 setting in kwargs; keep_batchnorm_fp32 must be a bool. The loss scale strategy can be changed by the loss_scale_manager setting in kwargs; loss_scale_manager should be an instance of a subclass of mindspore.amp.LossScaleManager.
- boost_level (str) – Option for the level argument of mindspore.boost, the level for boost mode training. Supports ["O0", "O1", "O2"]. Default: "O0".
- "O0": Do not change.
- "O1": Enable the boost mode; performance improves by about 20% and accuracy is the same as the original accuracy.
- "O2": Enable the boost mode; performance improves by about 30% and accuracy is reduced by less than 3%.
- To configure the boost mode yourself, set boost_config_dict as in boost.py. For this to work, the optimizer parameter must be set, along with at least one of eval_network or metrics.
- Notice: The optimization enabled by default currently applies only to some networks, and not all networks obtain the same benefit. It is recommended to enable this function in Graph mode on the Ascend platform; for better acceleration, refer to mindspore.boost.AutoBoost to configure boost_config_dict.
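The routing rule described for eval_indexes can be illustrated with a small plain-Python sketch. It is not MindSpore code: route_outputs and the sample outputs tuple are hypothetical names, and the sketch only models which outputs would reach the Loss metric and which would reach the other metrics.

```python
def route_outputs(outputs, eval_indexes=None):
    """Model which eval-network outputs each kind of metric receives.

    Illustrative only: a plain-Python reading of the eval_indexes
    description, not MindSpore's internal implementation.
    """
    if eval_indexes is None:
        # No indexes given: every metric sees all outputs, in order.
        return {"loss_metric": tuple(outputs), "other_metrics": tuple(outputs)}
    if len(eval_indexes) != 3:
        raise ValueError("eval_indexes must contain three elements")
    loss_pos, pred_pos, label_pos = eval_indexes
    return {
        # The loss value goes to the Loss metric only.
        "loss_metric": (outputs[loss_pos],),
        # The predicted value and the label go to all other metrics.
        "other_metrics": (outputs[pred_pos], outputs[label_pos]),
    }


# Hypothetical eval network that returns (loss, logits, label).
outputs = (0.25, [0.1, 0.9], 1)
routed = route_outputs(outputs, eval_indexes=[0, 1, 2])
print(routed["loss_metric"])
print(routed["other_metrics"])
```

With eval_indexes left as None, both entries contain the full outputs tuple, which matches the "all outputs are passed to metrics" case.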
 
 - Examples
    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>>
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    >>> model.train_network
    >>> model.predict_network
    >>> model.eval_network
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> dataset = create_dataset()
    >>> model.train(2, dataset)
 - build(train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True)
- Build computational graphs and data graphs with the sink mode.
- Warning
- This is an experimental API that is subject to change or deletion.
- Note
- If this interface is executed first to build the computational graphs, Model.train only executes the graphs. The pre-build process currently supports only GRAPH_MODE on the Ascend target, and only dataset sink mode.
- Parameters
- train_dataset (Dataset) – A training dataset iterator. If train_dataset is defined, training graphs will be built. Default: None.
- valid_dataset (Dataset) – An evaluation dataset iterator. If valid_dataset is defined, evaluation graphs will be built, and metrics in Model cannot be None. Default: None.
- sink_size (int) – Control the number of steps for each sinking. Default: -1.
- epoch (int) – Control the training epochs. Default: 1.
- sink_mode (bool) – Determines whether to pass the data through dataset channel. Default: True.
 
 - Examples
    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>> from mindspore.amp import FixedLossScaleManager
    >>>
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> dataset = create_dataset()
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits()
    >>> loss_scale_manager = FixedLossScaleManager()
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
    ...               loss_scale_manager=loss_scale_manager)
    >>> model.build(dataset, epoch=2)
    >>> model.train(2, dataset)
 - eval(valid_dataset, callbacks=None, dataset_sink_mode=False)
- Evaluation API.
- In PyNative mode or on CPU, evaluation is performed in dataset non-sink mode.
- Note
- If dataset_sink_mode is True, data will be sent to the device. At this point, the dataset is bound to this model and cannot be used by other models. If the device is Ascend, features of data will be transferred one by one. Each transfer is limited to 256M.
- The interface builds the computational graphs and then executes them. However, if Model.build is executed first, it only executes the graphs.
- Parameters
- valid_dataset (Dataset) – Dataset to evaluate the model. 
- callbacks (Optional[list[Callback], Callback]) – List of callback objects or a single callback object to be executed during evaluation. Default: None.
- dataset_sink_mode (bool) – Determines whether to pass the data through dataset channel. Default: False.
 
- Returns
- Dict, where each key is a metric name defined by the user and each value is the corresponding metric value of the model in test mode.
 - Examples
    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>>
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> dataset = create_dataset()
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
    >>> acc = model.eval(dataset, dataset_sink_mode=False)
 - property eval_network
- Get the model's eval network. - Returns
- Object, the instance of evaluate network. 
 
 - fit(epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None, dataset_sink_mode=False, valid_dataset_sink_mode=False, sink_size=-1, initial_epoch=0)
- Fit API.
- Evaluation is performed during training if valid_dataset is provided.
- For more details, refer to mindspore.train.Model.train() and mindspore.train.Model.eval().
- Parameters
- epoch (int) – Total number of training epochs. Generally, the train network is trained on the complete dataset in each epoch. If dataset_sink_mode is True and sink_size is greater than 0, each epoch trains sink_size steps instead of the total steps of the dataset. If epoch is used together with initial_epoch, it is to be understood as the "final epoch".
- train_dataset (Dataset) – A training dataset iterator. If loss_fn is defined, the data and label are passed to the network and the loss_fn respectively, so the dataset should return a tuple (data, label). If there are multiple data or labels, set loss_fn to None and implement the loss calculation in the network; a tuple (data1, data2, data3, …) with all data returned from the dataset is then passed to the network.
- valid_dataset (Dataset) – Dataset to evaluate the model. If valid_dataset is provided, evaluation is performed at the end of the training process. Default: None.
- valid_frequency (int, list) – Only relevant if valid_dataset is provided. If an integer, specifies how many training epochs to run before a new validation run is performed, e.g. valid_frequency=2 runs validation every 2 epochs. If a list, specifies the epochs on which to run validation, e.g. valid_frequency=[1, 5] runs validation at the end of the 1st and 5th epochs. Default: 1.
- callbacks (Optional[list[Callback], Callback]) – List of callback objects or a single callback object to be executed during training. Default: None.
- dataset_sink_mode (bool) – Determines whether to pass the training data through dataset channel. In PyNative mode or on CPU, training is performed in dataset non-sink mode. Default: False.
- valid_dataset_sink_mode (bool) – Determines whether to pass the validation data through dataset channel. Default: False.
- sink_size (int) – Control the number of steps for each sinking. sink_size is ignored if dataset_sink_mode is False. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. Default: -1.
- initial_epoch (int) – Epoch at which to start training; it is useful for resuming a previous training run. Default: 0.
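The two forms of valid_frequency can be read as a simple schedule over epoch numbers. The sketch below is a plain-Python model of that description and not MindSpore's internal logic; should_validate is a hypothetical helper, and counting epochs from 1 is an assumption taken from the "1st and 5th epochs" wording.

```python
def should_validate(epoch_number, valid_frequency=1):
    """Return True if validation would run at the end of this epoch.

    epoch_number is 1-based (assumption). valid_frequency is either an
    int (validate every N epochs) or a list of epoch numbers.
    """
    if isinstance(valid_frequency, int):
        if valid_frequency <= 0:
            raise ValueError("valid_frequency must be a positive integer")
        # Integer form: validate every valid_frequency epochs.
        return epoch_number % valid_frequency == 0
    # List form: validate only on the listed epochs.
    return epoch_number in valid_frequency


# Which of six training epochs end with a validation run?
every_two = [e for e in range(1, 7) if should_validate(e, 2)]
listed = [e for e in range(1, 7) if should_validate(e, [1, 5])]
print(every_two)
print(listed)
```

Under these assumptions, valid_frequency=2 selects epochs 2, 4 and 6, while valid_frequency=[1, 5] selects only epochs 1 and 5.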
 
 - Examples
    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>>
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> train_dataset = create_dataset("train")
    >>> valid_dataset = create_dataset("test")
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
    >>> model.fit(2, train_dataset, valid_dataset)
 - infer_predict_layout(*predict_data, skip_backend_compile=False)
- Generate parameter layout for the predict network when using AutoParallel(cell) to enable parallel mode.
- Data could be a single tensor or multiple tensors.
- Note
- Batch data should be put together in one tensor.
- Parameters
- predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional) – The predict data, which can be a single tensor, a list of tensors, or a tuple of tensors.
- skip_backend_compile (bool) – Only run the frontend compile process and skip the compile process on the device side. Setting this flag to True may prevent a later recompilation from hitting the cache. Default: False.
 
- Returns
- Dict, the parameter layout dictionary used to load a distributed checkpoint. It is always used as one of the input parameters of load_distributed_checkpoint.
- Raises
- RuntimeError – If not in GRAPH_MODE. 
 - Examples
    >>> import numpy as np
    >>> import mindspore as ms
    >>> import mindspore.nn as nn
    >>> from mindspore import Tensor
    >>> from mindspore.train import Model
    >>> from mindspore.ops import operations as P
    >>> from mindspore.communication import init
    >>> from mindspore.parallel.auto_parallel import AutoParallel
    >>>
    >>> class Net(nn.Cell):
    ...     def __init__(self):
    ...         super(Net, self).__init__()
    ...         self.fc1 = nn.Dense(128, 768, activation='relu')
    ...         self.fc2 = nn.Dense(128, 768, activation='relu')
    ...         self.fc3 = nn.Dense(128, 768, activation='relu')
    ...         self.fc4 = nn.Dense(768, 768, activation='relu')
    ...         self.relu4 = nn.ReLU()
    ...         self.relu5 = nn.ReLU()
    ...         self.transpose = P.Transpose()
    ...         self.matmul1 = P.MatMul()
    ...         self.matmul2 = P.MatMul()
    ...
    ...     def construct(self, x):
    ...         q = self.fc1(x)
    ...         k = self.fc2(x)
    ...         v = self.fc3(x)
    ...         k = self.transpose(k, (1, 0))
    ...         c = self.relu4(self.matmul1(q, k))
    ...         s = self.relu5(self.matmul2(c, v))
    ...         s = self.fc4(s)
    ...         return s
    ...
    >>> ms.set_context(mode=ms.GRAPH_MODE)
    >>> init()
    >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    >>> net = Net()
    >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
    >>> model = Model(parallel_net)
    >>> predict_map = model.infer_predict_layout(inputs)
 - infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=-1)
- Generate parameter layout for the train network when using AutoParallel(cell) to enable parallel mode.
- Only dataset sink mode is supported for now.
- Warning
- This is an experimental API that is subject to change or deletion.
- Note
- This is a pre-compile function. The arguments should be the same as those of the model.train() function.
- Parameters
- train_dataset (Dataset) – A training dataset iterator. If there is no loss_fn, a tuple with multiple data (data1, data2, data3, …) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned. The data and label would be passed to the network and loss function respectively. 
- dataset_sink_mode (bool) – Determines whether to pass the data through dataset channel. In PyNative mode or on CPU, training is performed in dataset non-sink mode. Default: True.
- sink_size (int) – Control the number of steps for each sinking. sink_size is ignored if dataset_sink_mode is False. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. Default: -1.
 
- Returns
- Dict, the parameter layout dictionary used to load a distributed checkpoint.
 - Examples
    >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
    >>> # mindspore.cn.
    >>> import numpy as np
    >>> import mindspore as ms
    >>> from mindspore import Tensor, nn
    >>> from mindspore.train import Model
    >>> from mindspore.communication import init
    >>> from mindspore.parallel.auto_parallel import AutoParallel
    >>>
    >>> ms.set_context(mode=ms.GRAPH_MODE)
    >>> init()
    >>>
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> dataset = create_dataset()
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> parallel_net = AutoParallel(net)
    >>> loss = nn.SoftmaxCrossEntropyWithLogits()
    >>> loss_scale_manager = ms.FixedLossScaleManager()
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
    ...               loss_scale_manager=loss_scale_manager)
    >>> layout_dict = model.infer_train_layout(dataset)
 - predict(*predict_data, backend=None, config=None)
- Generate output predictions for the input samples.
- Parameters
- predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional) – The predict data, which can be a single tensor, a list of tensors, or a tuple of tensors.
- backend (str) – Select the predict backend. This parameter is an experimental feature and is mainly used for MindSpore Lite cloud-side inference. Default: None.
- config (dict, optional) – The config includes two parts: config_path ("configPath", str) and config_item (str, dict). When config_item is set, it takes priority over config_path. The config can be used to set the rank table file for inference.
- config_path defines the path of a configuration file that is used to pass user-defined options during model building, for example "/home/user/config.ini". Default value: "". The content of the config.ini file is as follows:
    [ascend_context]
    rank_table_file = [path_a](storage initial path of the rank table file)
    [execution_plan]
    [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
    [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)
- When only config_path is configured:
    config = {"configPath" : "/home/user/config.ini"}
- When only config_item is configured:
    config = {"ascend_context" : {"rank_table_file" : "path_b"}, "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
- When both config_path and config_item are configured:
    config = {"configPath" : "/home/user/config.ini", "ascend_context" : {"rank_table_file" : "path_b"}, "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
- Note that when an option is set both in the file given by "configPath" and in config_item, config_item takes precedence; in the case above, path_b is used instead of path_a.
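The precedence rule for config can be modelled in a few lines of plain Python. This is only a sketch under stated assumptions, not the MindSpore Lite implementation: effective_config is a hypothetical helper, the file content is passed as a string so the example is self-contained, and merging option by option within each section is an assumption (the description only states that path_b wins over path_a).

```python
import configparser

# Stand-in for the content of /home/user/config.ini (simplified values).
INI_TEXT = """
[ascend_context]
rank_table_file = path_a

[execution_plan]
op_name1 = data_type:float16
op_name2 = data_type:float32
"""


def effective_config(ini_text, config):
    """Overlay the dict entries of config on top of the file settings."""
    # Split only on "=" so values such as "data_type:float16" stay intact.
    parser = configparser.ConfigParser(delimiters=("=",))
    parser.optionxform = str  # keep option names case-sensitive
    parser.read_string(ini_text)
    merged = {name: dict(parser[name]) for name in parser.sections()}
    for section, options in config.items():
        if section == "configPath":
            continue  # the path itself is not a setting section
        # Entries passed in the dict override entries read from the file.
        merged.setdefault(section, {}).update(options)
    return merged


config = {
    "configPath": "/home/user/config.ini",
    "ascend_context": {"rank_table_file": "path_b"},
    "execution_plan": {"op_name3": "data_type:float16",
                       "op_name4": "data_type:float32"},
}
merged = effective_config(INI_TEXT, config)
print(merged["ascend_context"]["rank_table_file"])
print(sorted(merged["execution_plan"]))
```

In this model the rank table path resolves to path_b, and the execution plan contains the operators from both the file and the dict.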
 
- Returns
- Tensor, array(s) of predictions. 
 - Examples
    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor
    >>> from mindspore.train import Model
    >>>
    >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> model = Model(LeNet5())
    >>> result = model.predict(input_data)
 - property predict_network
- Get the model's predict network. - Returns
- Object, the instance of predict network. 
 
 - train(epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0)
- Training API.
- In PyNative mode or on CPU, training is performed in dataset non-sink mode.
- Note
- If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features of data will be transferred one by one. Each transfer is limited to 256M.
- When dataset_sink_mode is True, the step_end method of the Callback instance is called at the end of each step in PyNative mode, or at the end of each epoch in Graph mode.
- If dataset_sink_mode is True, the dataset is bound to this model and cannot be used by other models.
- If sink_size > 0, each epoch may traverse the dataset any number of times until sink_size elements have been obtained. The next epoch continues traversing from the position where the previous traversal ended.
- The interface builds the computational graphs and then executes them. However, if Model.build is executed first, it only executes the graphs.
- Parameters
- epoch (int) – Total number of training epochs. Generally, the train network is trained on the complete dataset in each epoch. If dataset_sink_mode is True and sink_size is greater than 0, each epoch trains sink_size steps instead of the total steps of the dataset. If epoch is used together with initial_epoch, it is to be understood as the "final epoch".
- train_dataset (Dataset) – A training dataset iterator. If loss_fn is defined, the data and label are passed to the network and the loss_fn respectively, so the dataset should return a tuple (data, label). If there are multiple data or labels, set loss_fn to None and implement the loss calculation in the network; a tuple (data1, data2, data3, …) with all data returned from the dataset is then passed to the network.
- callbacks (Optional[list[Callback], Callback]) – List of callback objects or a single callback object to be executed during training. Default: None.
- dataset_sink_mode (bool) – Determines whether to pass the data through dataset channel. In PyNative mode or on CPU, training is performed in dataset non-sink mode. Default: False.
- sink_size (int) – Control the number of steps for each sinking. sink_size is ignored if dataset_sink_mode is False. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. Default: -1.
- initial_epoch (int) – Epoch at which to start training; it is used for resuming a previous training run. Default: 0.
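How epoch, initial_epoch and sink_size interact can be sketched in plain Python. The helpers below (epochs_to_run, steps_per_epoch, sinked_batches) are hypothetical and only model the behaviour described in the notes and parameters above; they do not call MindSpore, and the batch names are placeholders.

```python
from itertools import cycle, islice


def epochs_to_run(epoch, initial_epoch=0):
    """Epoch indices that are trained when epoch is the 'final epoch'."""
    return list(range(initial_epoch, epoch))


def steps_per_epoch(dataset_size, dataset_sink_mode=False, sink_size=-1):
    """Steps in one epoch: sink_size only matters in sink mode when > 0."""
    if dataset_sink_mode and sink_size > 0:
        return sink_size
    return dataset_size


def sinked_batches(batches, num_epochs, sink_size):
    """Batches seen in each epoch when sink_size > 0.

    The dataset is traversed repeatedly, and each epoch resumes where
    the previous one stopped.
    """
    stream = cycle(batches)
    return [list(islice(stream, sink_size)) for _ in range(num_epochs)]


resumed = epochs_to_run(epoch=5, initial_epoch=2)
plan = sinked_batches(["b0", "b1", "b2", "b3", "b4"], num_epochs=3, sink_size=2)
print(resumed)
print(plan)
```

In this model, resuming with initial_epoch=2 and epoch=5 trains three more epochs, and with five batches and sink_size=2 the third epoch wraps around to the start of the dataset.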
 
 - Examples
    >>> import mindspore as ms
    >>> from mindspore import nn
    >>> from mindspore.train import Model
    >>>
    >>> # Create the dataset taking MNIST as an example. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/mnist.py
    >>> dataset = create_dataset()
    >>> # Define the network structure of LeNet5. Refer to
    >>> # https://gitee.com/mindspore/docs/blob/r2.6.0/docs/mindspore/code/lenet.py
    >>> net = LeNet5()
    >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> loss_scale_manager = ms.FixedLossScaleManager(1024., False)
    >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
    ...               loss_scale_manager=loss_scale_manager)
    >>> model.train(2, dataset)
 - property train_network
- Get the model's train network. - Returns
- Object, the instance of train network.