mindspore.Model

class mindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level='O0', boost_level='O0', **kwargs)[source]

High-Level API for training or inference.

Model groups layers into an object with training and inference features based on the arguments.

Parameters
  • network (Cell) – A training or testing network.

  • loss_fn (Cell) – Objective function. If loss_fn is None, the network should contain the loss calculation and, if needed, the parallel logic. Default: None.

  • optimizer (Cell) – Optimizer for updating the weights. If optimizer is None, the network itself must perform backpropagation and update the weights. Default: None.

  • metrics (Union[dict, set]) – A dictionary or a set of metrics for model evaluation, e.g. {‘accuracy’, ‘recall’}. Default: None.

  • eval_network (Cell) – Network for evaluation. If not defined, network and loss_fn are wrapped as eval_network. Default: None.

  • eval_indexes (list) – Used only when eval_network is defined. If eval_indexes is None (the default), all outputs of the eval_network are passed to the metrics. If eval_indexes is set, it must contain three elements: the positions of the loss value, the predicted value and the label in the outputs of the eval_network. In that case, the loss value is passed to the Loss metric, and the predicted value and label are passed to the other metrics. mindspore.nn.Metric.set_indexes is recommended instead of eval_indexes. Default: None.

  • amp_level (str) –

    Option for argument level in mindspore.build_train_network(), level for mixed precision training. Supports [“O0”, “O2”, “O3”, “auto”]. Default: “O0”.

    • O0: Do not change.

    • O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.

    • O3: Cast network to float16, the batchnorm is also cast to float16, loss scale will not be used.

    • auto: Set the level to the one recommended for the device: O2 on GPU, O3 on Ascend. The recommended level is chosen based on expert experience and does not apply to all scenarios. Users should specify the level explicitly for special networks.

    O2 is recommended on GPU, O3 is recommended on Ascend. The batchnorm strategy can be changed with the keep_batchnorm_fp32 setting in kwargs. keep_batchnorm_fp32 must be a bool. The loss scale strategy can be changed with the loss_scale_manager setting in kwargs. loss_scale_manager should be a subclass of mindspore.LossScaleManager. A more detailed explanation of the amp_level setting can be found at mindspore.build_train_network.

  • boost_level (str) –

    Option for argument level in mindspore.boost, level for boost mode training. Supports [“O0”, “O1”, “O2”]. Default: “O0”.

    • O0: Do not change.

    • O1: Enable the boost mode, the performance is improved by about 20%, and the accuracy is the same as the original accuracy.

    • O2: Enable the boost mode, the performance is improved by about 30%, and the accuracy is reduced by less than 3%.

    To configure boost mode yourself, set boost_config_dict as described in boost.py.
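
The eval_indexes routing described above can be sketched in plain Python. This is only an illustration under stated assumptions, not MindSpore code: route_eval_outputs is a hypothetical helper, the dictionary keys are made up for readability, and the strings stand in for the tensors an eval_network would return.

```python
def route_eval_outputs(outputs, eval_indexes=None):
    """Hypothetical helper: show which eval_network outputs reach which metrics."""
    if eval_indexes is None:
        # Default: every output is handed to the metrics unchanged.
        return {"all_metrics": tuple(outputs)}
    if len(eval_indexes) != 3:
        raise ValueError("eval_indexes must contain exactly three positions")
    # Positions of the loss value, the predicted value and the label.
    loss_pos, pred_pos, label_pos = eval_indexes
    return {
        "loss_metric": (outputs[loss_pos],),
        "other_metrics": (outputs[pred_pos], outputs[label_pos]),
    }


# Placeholder strings stand in for tensors.
outputs = ("loss_value", "logits", "labels")
routed = route_eval_outputs(outputs, eval_indexes=[0, 1, 2])
unrouted = route_eval_outputs(outputs)
print(routed)
print(unrouted)
```

With eval_indexes=[0, 1, 2], the first output goes to the Loss metric and the second and third go to the remaining metrics; with eval_indexes=None, all three outputs are passed on together.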

Examples

>>> from mindspore import Model, nn
>>>
>>> class Net(nn.Cell):
...     def __init__(self, num_class=10, num_channel=1):
...         super(Net, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...
...     def construct(self, x):
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         return x
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> model.train(2, dataset)
build(train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None)[source]

Build computational graphs and data graphs with the sink mode.

Warning

This is an experimental prototype that is subject to change or deletion.

Note

This interface builds the computational graphs in advance. If it is executed first, ‘Model.train’ only executes the graphs. The pre-build process currently supports only GRAPH_MODE on the Ascend target, and only dataset sink mode.

Parameters
  • train_dataset (Dataset) – A training dataset iterator. If train_dataset is defined, training graphs will be built. Default: None.

  • valid_dataset (Dataset) – An evaluation dataset iterator. If valid_dataset is defined, evaluation graphs will be built, and metrics in Model cannot be None. Default: None.

  • sink_size (int) – Control the amount of data in each sink. Default: -1.

  • epoch (int) – Control the training epochs. Default: 1.

  • jit_config (Dict[str, str]) –

    Control the jit config. If set to None (the default), the graph is compiled with the default behavior. You can customize the compile config with a dictionary. For example, you can set {‘jit_level’: ‘o0’} to control the jit level. The supported keys are listed below. Default: None.

    • jit_level (string): Control the graph compile optimization level. Options: o0/o1. Default: o1. If set to o0, graph compilation skips the combine-like-graph phase.
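
The jit_config values described above can be written down as plain dictionaries. The sketch below adds a small pre-flight check; check_jit_config and DOCUMENTED_JIT_LEVELS are hypothetical names introduced for illustration and are not part of MindSpore, and the check only encodes what this page states (a dictionary with a jit_level of o0 or o1, defaulting to o1).

```python
DOCUMENTED_JIT_LEVELS = {"o0", "o1"}  # levels listed on this page


def check_jit_config(jit_config):
    """Hypothetical check (not a MindSpore API) against the documented options."""
    if jit_config is None:
        # None means: compile with the default behavior.
        return None
    if not isinstance(jit_config, dict):
        raise TypeError("jit_config should be a dictionary or None")
    level = jit_config.get("jit_level", "o1")  # o1 is the documented default
    if level not in DOCUMENTED_JIT_LEVELS:
        raise ValueError("jit_level should be 'o0' or 'o1'")
    return {"jit_level": level}


checked_default = check_jit_config(None)
checked_o0 = check_jit_config({"jit_level": "o0"})
checked_empty = check_jit_config({})
print(checked_default, checked_o0, checked_empty)
```

A checked dictionary would then be passed as the jit_config argument of build, for example model.build(dataset, epoch=2, jit_config={"jit_level": "o0"}).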

Examples

>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.build(dataset, epoch=2)
>>> model.train(2, dataset)
eval(valid_dataset, callbacks=None, dataset_sink_mode=True)[source]

Evaluation API.

In pynative mode or on CPU, evaluation is performed in dataset non-sink mode.

Note

If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features of the data will be transferred one by one. Each transfer is limited to 256M of data.

If dataset_sink_mode is True, the dataset will be bound to this model and cannot be used by other models.

The interface builds the computational graphs and then executes them. However, if Model.build is executed first, it only executes the graphs.

Parameters
  • valid_dataset (Dataset) – Dataset to evaluate the model.

  • callbacks (Optional[list[Callback], Callback]) – A callback object or a list of callback objects to be executed during evaluation. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through dataset channel. Default: True.

Returns

Dict, where each key is a metric name defined by the user and each value is the value of that metric for the model in test mode.

Examples

>>> from mindspore import Model, nn
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> # eval returns a dict keyed by metric name, here {'acc': ...}
>>> acc = model.eval(dataset, dataset_sink_mode=False)
property eval_network

Get the model’s eval network.

Returns

Object, the instance of the evaluation network.

infer_predict_layout(*predict_data)[source]

Generate parameter layout for the predict network in ‘AUTO_PARALLEL’ or ‘SEMI_AUTO_PARALLEL’ mode.

Data could be a single tensor or multiple tensors.

Note

Batch data should be put together in one tensor.

Parameters

predict_data (Tensor) – One tensor or multiple tensors of predict data.

Returns

Dict, the parameter layout dictionary used to load a distributed checkpoint. It is always used as one of the input parameters of load_distributed_checkpoint.

Raises

RuntimeError – If not in GRAPH_MODE.

Examples

>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
>>> # mindspore.cn.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> predict_map = model.infer_predict_layout(input_data)
infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=-1)[source]

Generate parameter layout for the train network in ‘AUTO_PARALLEL’ or ‘SEMI_AUTO_PARALLEL’ mode. Only dataset sink mode is supported for now.

Warning

This is an experimental prototype that is subject to change or deletion.

Note

This is a pre-compile function. The arguments should be the same as those passed to model.train().

Parameters
  • train_dataset (Dataset) – A training dataset iterator. If there is no loss_fn, a tuple with multiple data (data1, data2, data3, …) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned. The data and label would be passed to the network and loss function respectively.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. In pynative mode or on CPU, training is performed in dataset non-sink mode. Default: True.

  • sink_size (int) – Control the amount of data in each sink. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. If dataset_sink_mode is False, sink_size is ignored. Default: -1.

Returns

Dict, the parameter layout dictionary used to load a distributed checkpoint.

Examples

>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
>>> # mindspore.cn.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor, nn, FixedLossScaleManager
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> layout_dict = model.infer_train_layout(dataset)
predict(*predict_data)[source]

Generate output predictions for the input samples.

Parameters

predict_data (Optional[Tensor, list[Tensor], tuple[Tensor]]) – The predict data, which can be a single tensor, a list of tensors, or a tuple of tensors.

Returns

Tensor, array(s) of predictions.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, Tensor
>>>
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> result = model.predict(input_data)
property predict_network

Get the model’s predict network.

Returns

Object, the instance of predict network.

train(epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1)[source]

Training API.

In pynative mode or on CPU, training is performed in dataset non-sink mode.

Note

If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features of the data will be transferred one by one. Each transfer is limited to 256M of data.

When dataset_sink_mode is True, the step_end method of each Callback instance is called at the end of the epoch.

If dataset_sink_mode is True, the dataset will be bound to this model and cannot be used by other models.

If sink_size > 0, the dataset can be traversed repeatedly within an epoch until sink_size elements have been taken. The next epoch continues from the position where the previous traversal ended.

The interface builds the computational graphs and then executes them. However, if Model.build is executed first, it only executes the graphs.

Parameters
  • epoch (int) – Total number of training epochs. Generally, the train network is trained on the complete dataset in each epoch. If dataset_sink_mode is set to True and sink_size is greater than 0, each epoch trains sink_size steps instead of the total number of steps in the dataset.

  • train_dataset (Dataset) – A training dataset iterator. If loss_fn is defined, the data and label will be passed to the network and the loss_fn respectively, so a tuple (data, label) should be returned from the dataset. If there are multiple data items or labels, set loss_fn to None and implement the loss calculation in the network; a tuple (data1, data2, data3, …) with all data returned from the dataset will then be passed to the network.

  • callbacks (Optional[list[Callback], Callback]) – A callback object or a list of callback objects to be executed during training. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. In pynative mode or on CPU, training is performed in dataset non-sink mode. Default: True.

  • sink_size (int) – Control the amount of data in each sink. sink_size is ignored if dataset_sink_mode is False. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. Default: -1.
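
The effect of sink_size on what each epoch consumes can be illustrated with a small simulation. This is a sketch in plain Python that models the behavior described above and does not call MindSpore: sunk_batches is a hypothetical helper, and the list of strings stands in for the batches of a real dataset.

```python
from itertools import cycle, islice


def sunk_batches(dataset, epochs, sink_size=-1):
    """Hypothetical model: list the batches each epoch would consume in sink mode."""
    if sink_size == -1:
        # Sink the complete dataset for each epoch.
        return [list(dataset) for _ in range(epochs)]
    if sink_size <= 0:
        raise ValueError("sink_size must be -1 or a positive integer")
    # One shared stream: each epoch resumes where the previous one stopped,
    # and the stream wraps around when the dataset is exhausted.
    stream = cycle(dataset)
    return [list(islice(stream, sink_size)) for _ in range(epochs)]


batches = ["b0", "b1", "b2", "b3", "b4"]
full = sunk_batches(batches, epochs=2)
partial = sunk_batches(batches, epochs=3, sink_size=2)
print(full)
print(partial)
```

With five batches and sink_size=2, the three epochs consume b0 and b1, then b2 and b3, then b4 and b0: the third epoch wraps around to the start of the dataset, which matches the note that the next epoch continues from the end position of the previous traversal.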

Examples

>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
property train_network

Get the model’s train network.

Returns

Object, the instance of train network.