


class mindspore.train.model.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level='O0', **kwargs)

High-Level API for Training or Testing.

Model groups layers into an object with training and inference features.

  • network (Cell) – A training or testing network.

  • loss_fn (Cell) – Objective function. If loss_fn is None, the network should contain the loss and gradient calculation logic, and the parallel logic if needed. Default: None.

  • optimizer (Cell) – Optimizer for updating the weights. Default: None.

  • metrics (Union[dict, set]) – A dictionary or a set of metrics to be evaluated by the model during training and testing, e.g. {‘accuracy’, ‘recall’}. Default: None.

  • eval_network (Cell) – Network for evaluation. If not defined, network and loss_fn will be wrapped as the eval_network. Default: None.

  • eval_indexes (list) – Used together with eval_network. If eval_indexes is None, all outputs of the eval_network are passed to metrics; otherwise, eval_indexes must contain three elements: the positions of the loss value, the predicted value and the label. The loss value is passed to the Loss metric, and the predicted value and label are passed to the other metrics. Default: None.

  • amp_level (str) –

    Option for the level argument of mindspore.amp.build_train_network, i.e. the level for mixed precision training. Supports [“O0”, “O2”, “O3”, “auto”]. Default: “O0”.

    • O0: Do not change.

    • O2: Cast the network to float16, keep BatchNorm running in float32, and use dynamic loss scaling.

    • O3: Cast network to float16, with additional property ‘keep_batchnorm_fp32=False’.

    • auto: Set the level to the recommended level for the current device: O2 on GPU and O3 on Ascend. The recommended levels are based on expert experience and may not generalize to every network; users should specify the level explicitly for special networks.

    O2 is recommended on GPU, O3 is recommended on Ascend.

  • loss_scale_manager (Union[None, LossScaleManager]) – If None, the loss is not scaled. Otherwise, the loss is scaled by the LossScaleManager, and optimizer cannot be None. It is a keyword argument, e.g. loss_scale_manager=None.

  • keep_batchnorm_fp32 (bool) – Keep BatchNorm running in float32. If it is set, the amp_level setting is overridden. Default: True.
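As a rough illustration of the eval_indexes rule above, the following plain-Python sketch shows how the outputs of an eval_network could be routed to metrics. It is not MindSpore's implementation: route_eval_outputs and the keys of the returned dictionary are hypothetical, and it assumes eval_indexes is ordered as [loss position, predicted value position, label position].

```python
from typing import Any, Dict, Optional, Sequence


def route_eval_outputs(outputs: Sequence[Any],
                       eval_indexes: Optional[Sequence[int]] = None) -> Dict[str, Any]:
    """Hypothetical helper: decide which eval_network outputs go to which metrics."""
    if eval_indexes is None:
        # Without eval_indexes, every output is handed to the metrics unchanged.
        return {"all_metrics": tuple(outputs)}
    if len(eval_indexes) != 3:
        raise ValueError("eval_indexes must contain exactly three elements")
    loss_idx, pred_idx, label_idx = eval_indexes
    return {
        # The loss value feeds the Loss metric ...
        "loss_metric": outputs[loss_idx],
        # ... while (predicted value, label) feed the other metrics.
        "other_metrics": (outputs[pred_idx], outputs[label_idx]),
    }
```

Because only positions are given, the eval_network may return its three values in any order, e.g. eval_indexes=[2, 0, 1] for a network returning (prediction, label, loss).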


>>> import mindspore.nn as nn
>>> from mindspore.nn import Momentum
>>> from mindspore.train.model import Model
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
...         self.bn = nn.BatchNorm2d(64)
...         self.relu = nn.ReLU()
...         self.flatten = nn.Flatten()
...         # Conv2d uses pad_mode='same' by default, so a 224x224 input keeps its spatial size
...         self.fc = nn.Dense(64*224*224, 12)
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         x = self.relu(x)
...         x = self.flatten(x)
...         out = self.fc(x)
...         return out
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> # get_dataset() stands for a user-defined function that returns a training Dataset
>>> dataset = get_dataset()
>>> model.train(2, dataset)
eval(valid_dataset, callbacks=None, dataset_sink_mode=True)

Evaluation API where the iteration is controlled by the Python front end.

In PyNative mode or on CPU, evaluation is performed in dataset non-sink mode.


If dataset_sink_mode is True, data is sent to the device. On Ascend, data features are transferred one by one, and each transfer is limited to 256M.

  • valid_dataset (Dataset) – Dataset to evaluate the model.

  • callbacks (list) – List of callback objects which should be executed during evaluation. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. Default: True.


Dict, containing the loss value and metric values of the model in test mode.


>>> dataset = get_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset)

Generate output predictions for the input samples.

Data can be a single tensor, a list of tensors, or a tuple of tensors.


Batch data should be put together in one tensor.


predict_data (Tensor) – The input data for prediction. It can be a tensor, or a list or tuple of tensors.


Tensor, array(s) of predictions.


>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> model = Model(Net())
>>> model.predict(input_data)
train(epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1)

Training API where the iteration is controlled by the Python front end.

In PyNative mode or on CPU, training is performed in dataset non-sink mode.


If dataset_sink_mode is True, the number of training epochs should equal the repeat count used in dataset processing; otherwise, errors may occur because the amount of data does not match the amount required for training. If dataset_sink_mode is True, data is sent to the device. On Ascend, data features are transferred one by one, and each transfer is limited to 256M. If sink_size > 0, in each epoch the dataset can be traversed an unlimited number of times until sink_size elements have been obtained; the next epoch continues traversing from where the previous one ended.

  • epoch (int) – Number of training epochs. Generally, each epoch iterates over the whole dataset. When dataset_sink_mode is True and sink_size > 0, each epoch instead sinks sink_size steps of data.

  • train_dataset (Dataset) – A training dataset iterator. If there is no loss_fn, a tuple with multiple data (data1, data2, data3, …) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned; the data and label are passed to the network and the loss function, respectively.

  • callbacks (list) – List of callback objects which should be executed while training. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. Default: True. In PyNative mode or on CPU, training is performed in dataset non-sink mode.

  • sink_size (int) – Controls the amount of data in each sink. If sink_size = -1, the complete dataset is sunk in each epoch. If sink_size > 0, sink_size data is sunk in each epoch. If dataset_sink_mode is False, sink_size has no effect. Default: -1.


>>> import mindspore.nn as nn
>>> from mindspore.nn import Momentum
>>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
>>> dataset = get_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
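The sink_size rule described for train can be modeled on a Python list, assuming a real dataset behaves like a repeating stream when sink_size > 0. sink_epochs is a hypothetical helper, not part of MindSpore; it only shows which elements each epoch would consume.

```python
from itertools import cycle, islice
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sink_epochs(dataset: Sequence[T], epochs: int, sink_size: int = -1) -> List[List[T]]:
    """Hypothetical model of which elements each epoch consumes."""
    if sink_size == -1:
        # Sink the complete dataset in every epoch.
        return [list(dataset) for _ in range(epochs)]
    if sink_size <= 0:
        raise ValueError("sink_size must be -1 or a positive integer")
    # One shared, endlessly repeating stream: each epoch takes the next sink_size
    # items, so the next epoch continues from where the previous one ended.
    stream = cycle(dataset)
    return [list(islice(stream, sink_size)) for _ in range(epochs)]
```

For example, with five elements and sink_size=3, the second epoch picks up at element 3 and wraps around to element 0.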


Dataset helper for MindData datasets.

class mindspore.train.dataset_helper.DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1)

DatasetHelper processes a MindData dataset and provides information about it.

It adapts the dataset iteration to the current context, so that the same for-loop can be used in different contexts.


Iterating over DatasetHelper provides one epoch of data.

  • dataset (DataSet) – The training dataset iterator.

  • dataset_sink_mode (bool) – If True, use GetNext to fetch the data; otherwise, feed the data from the host. Default: True.

  • sink_size (int) – Control the amount of data in each sink. If sink_size=-1, sink the complete dataset for each epoch. If sink_size>0, sink sink_size data for each epoch. Default: -1.

  • epoch_num (int) – Number of epochs of data to send. Default: 1.


>>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper:
...     outputs = network(*inputs)

Continue sending data to the device at the beginning of each epoch.


Get sink_size for each iteration.


Free up the resources used by data sink.


Get the types and shapes from dataset on the current configuration.

mindspore.train.dataset_helper.connect_network_with_dataset(network, dataset_helper)

Connect the network with dataset in dataset_helper.

This function wraps the input network with ‘GetNext’ so that the data can be fetched automatically from the data channel corresponding to the ‘queue_name’ and passed to the input network during forward computation.


When running the network on Ascend in graph mode, this function wraps the input network with ‘GetNext’; in other cases, the input network is returned unchanged. ‘GetNext’ is required to get data only in sink mode, so this function is not applicable to non-sink mode.

  • network (Cell) – The training network for dataset.

  • dataset_helper (DatasetHelper) – A class that processes the MindData dataset; it provides the types, shapes and queue name of the dataset used to wrap GetNext.


Cell, a new network wrapped with ‘GetNext’ in the case of running the task on Ascend in graph mode, otherwise it is the input network.


>>> from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_dataset()
>>> dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> net = Net()
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)



Users can use SummaryRecord to dump summary data. Summary is a series of operations that collect data for analysis and visualization.

class mindspore.train.summary.SummaryRecord(log_dir, file_prefix='events', file_suffix='_MS', network=None, max_file_size=None)

SummaryRecord is used to record the summary data and lineage data.

This API lazily creates a summary file and lineage files in the given directory and writes data to them when the ‘record’ method is executed. In addition to recording the data bubbled up from the network through summary operators, SummaryRecord can also record extra data added by calling add_value.


  1. Make sure to close the SummaryRecord at the end; otherwise, the process will not exit. See the Example section below for two ways to close it properly.

  2. Only one SummaryRecord instance is allowed at a time; otherwise, data writing problems will occur.

  3. SummaryRecord only supports Linux systems.

  • log_dir (str) – The directory in which to save the summary.

  • file_prefix (str) – The prefix of the file name. Default: “events”.

  • file_suffix (str) – The suffix of the file name. Default: “_MS”.

  • network (Cell) – Obtain a pipeline through the network for saving the graph summary. Default: None.

  • max_file_size (Optional[int]) – The maximum size of each file that can be written to disk (in bytes). Unlimited by default. For example, to write not larger than 4GB, specify max_file_size=4 * 1024**3.

  • TypeError – If the type of max_file_size is not int, or the type of file_prefix or file_suffix is not str.

  • RuntimeError – If the log_dir is not a normalized absolute path name.


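>>> from mindspore.train.summary import SummaryRecord
>>> # use in a with statement to close automatically
>>> with SummaryRecord(log_dir="./summary_dir") as summary_record:
...     pass
>>> # use try ... finally ... to ensure closing; create the recorder first so that
>>> # summary_record is always defined when the finally clause runs
>>> summary_record = SummaryRecord(log_dir="./summary_dir")
>>> try:
...     pass
... finally:
...     summary_record.close()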
add_value(plugin, name, value)

Add value to be recorded later.

When the plugin is ‘tensor’, ‘scalar’, ‘image’ or ‘histogram’, the name should be the tag name, and the value should be a Tensor.

When the plugin is ‘graph’, the value should be a GraphProto.

When the plugin is ‘dataset_graph’, ‘train_lineage’, ‘eval_lineage’, or ‘custom_lineage_data’, the value should be a proto message.

  • plugin (str) – The name of the plugin.

  • name (str) – The name (tag) of the value.

  • value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]) –

    The value to store.

    • The data type of value should be ‘GraphProto’ when the plugin is ‘graph’.

    • The data type of value should be ‘Tensor’ when the plugin is ‘scalar’, ‘image’, ‘tensor’ or ‘histogram’.

    • The data type of value should be ‘TrainLineage’ when the plugin is ‘train_lineage’.

    • The data type of value should be ‘EvaluationLineage’ when the plugin is ‘eval_lineage’.

    • The data type of value should be ‘DatasetGraph’ when the plugin is ‘dataset_graph’.

    • The data type of value should be ‘UserDefinedInfo’ when the plugin is ‘custom_lineage_data’.
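The plugin-to-type pairs listed above can be kept as a lookup table, for example to validate values before calling add_value. This is a hypothetical helper; it works with type names as strings so it runs without MindSpore installed.

```python
# Expected value type (by class name) for each plugin accepted by add_value.
EXPECTED_VALUE_TYPE = {
    "graph": "GraphProto",
    "scalar": "Tensor",
    "image": "Tensor",
    "tensor": "Tensor",
    "histogram": "Tensor",
    "train_lineage": "TrainLineage",
    "eval_lineage": "EvaluationLineage",
    "dataset_graph": "DatasetGraph",
    "custom_lineage_data": "UserDefinedInfo",
}


def expected_value_type(plugin: str) -> str:
    """Hypothetical helper: return the type name add_value expects for a plugin."""
    try:
        return EXPECTED_VALUE_TYPE[plugin]
    except KeyError:
        raise ValueError(f"unsupported plugin: {plugin!r}") from None
```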



>>> from mindspore import Tensor
>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
...     summary_record.add_value('scalar', 'loss', Tensor(0.1))

Flush all events and close the summary records. Please use the with statement to close automatically.


>>> summary_record = SummaryRecord(log_dir="./summary_dir")
>>> try:
...     pass
... finally:
...     summary_record.close()

Flush the event file to disk.

Call it to make sure that all pending events have been written to disk.


>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
...     summary_record.flush()
property log_dir

Get the full path of the log file.


str, the full path of log file.


>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
...     print(summary_record.log_dir)
record(step, train_network=None, plugin_filter=None)

Record the summary.

  • step (int) – The training step number.

  • train_network (Cell) – The network that called the callback.

  • plugin_filter (Optional[Callable[[str], bool]]) – A filter function; plugins for which it returns False are not written.


bool, whether the record process is successful or not.


>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
...     summary_record.record(step=2)
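plugin_filter receives a plugin name and returns False for plugins that should not be written. Below is a sketch of such a filter, together with a hypothetical apply_plugin_filter that mimics the filtering step (it is not MindSpore's code). With a real recorder, the filter would be passed as summary_record.record(step=2, plugin_filter=scalars_only).

```python
from typing import Callable, Iterable, List, Optional


def scalars_only(plugin: str) -> bool:
    """Keep scalar summaries; returning False keeps a plugin from being written."""
    return plugin == "scalar"


def apply_plugin_filter(plugins: Iterable[str],
                        plugin_filter: Optional[Callable[[str], bool]] = None) -> List[str]:
    """Hypothetical stand-in for the filtering step inside record()."""
    if plugin_filter is None:
        # No filter: every plugin is written.
        return list(plugins)
    return [plugin for plugin in plugins if plugin_filter(plugin)]
```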

Set the mode the recorder should be aware of. The mode is set to ‘train’ by default.


mode (str) – The mode to be set, which should be ‘train’ or ‘eval’.


ValueError – When the mode is not recognized.


>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
...     summary_record.set_mode('eval')


Callback related classes and functions.

class mindspore.train.callback.Callback

Abstract base class used to build a callback class. Callbacks are context managers which will be entered and exited when passed into the Model. You can use this mechanism to initialize and release resources automatically.

Callback function will execute some operations in the current step or epoch.


>>> from mindspore.train.callback import Callback
>>> class Print_info(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print(cb_params.cur_epoch_num)
...         print(cb_params.cur_step_num)
>>> print_cb = Print_info()
>>> model.train(epoch, dataset, callbacks=print_cb)

Called once before the network executes.


run_context (RunContext) – Contains information about the model.


Called once after network training.


run_context (RunContext) – Contains information about the model.


Called at the beginning of each epoch.


run_context (RunContext) – Contains information about the model.


Called at the end of each epoch.


run_context (RunContext) – Contains information about the model.


Called at the beginning of each step.


run_context (RunContext) – Contains information about the model.


Called at the end of each step.


run_context (RunContext) – Contains information about the model.
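The hooks documented above are nested: the training-level hooks wrap the epoch hooks, which wrap the step hooks. The sketch below traces that order for a run without data sinking, using MindSpore's hook names begin, end, epoch_begin, epoch_end, step_begin and step_end (step_end also appears in the Print_info example). callback_call_order is hypothetical and only models the sequence; in data sink mode the step hooks are expected to fire once per sink rather than once per step.

```python
from typing import List


def callback_call_order(epochs: int, steps_per_epoch: int) -> List[str]:
    """Hypothetical trace of callback hook calls for a non-sink training run."""
    calls = ["begin"]  # once, before the network executes
    for _ in range(epochs):
        calls.append("epoch_begin")
        for _ in range(steps_per_epoch):
            calls.extend(["step_begin", "step_end"])
        calls.append("epoch_end")
    calls.append("end")  # once, after training
    return calls
```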

class mindspore.train.callback.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False)

The configuration of model checkpoint.


During the training process, if the dataset is transmitted through the data channel, it is recommended to set ‘save_checkpoint_steps’ to an integer multiple of loop_size. Otherwise, the time at which checkpoints are saved may deviate from the expected steps.

  • save_checkpoint_steps (int) – Steps to save checkpoint. Default: 1.

  • save_checkpoint_seconds (int) – Seconds to save checkpoint. Default: 0. Can’t be used with save_checkpoint_steps at the same time.

  • keep_checkpoint_max (int) – Maximum number of checkpoint files that can be saved. Default: 5.

  • keep_checkpoint_per_n_minutes (int) – Keep one checkpoint every n minutes. Default: 0. Can’t be used with keep_checkpoint_max at the same time.

  • integrated_save (bool) – Whether to perform integrated saving in the automatic model parallel scenario. Default: True. Integrated saving is only supported in the automatic parallel scenario, not in manual parallel.

  • async_save (bool) – Whether to save the checkpoint to a file asynchronously. Default: False


ValueError – If the input_param is None or 0.


>>> config = CheckpointConfig()
>>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)
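Why save_checkpoint_steps should be a multiple of loop_size can be seen with a simplified model. It assumes that, with data sinking, the checkpoint callback only runs at the end of each sink of loop_size steps, and that it saves once at least save_checkpoint_steps steps have passed since the last save. checkpoint_steps is a hypothetical helper, not MindSpore code.

```python
from typing import List


def checkpoint_steps(save_checkpoint_steps: int, loop_size: int, total_steps: int) -> List[int]:
    """Hypothetical model of the steps at which step-based checkpoints land."""
    if save_checkpoint_steps <= 0 or loop_size <= 0:
        raise ValueError("save_checkpoint_steps and loop_size must be positive")
    saved: List[int] = []
    last_saved = 0
    # With sinking, the callback only observes the step counter every loop_size steps.
    for step in range(loop_size, total_steps + 1, loop_size):
        if step - last_saved >= save_checkpoint_steps:
            saved.append(step)
            last_saved = step
    return saved
```

With save_checkpoint_steps=5 and loop_size=2 over 12 steps, checkpoints land at steps 6 and 12 instead of 5 and 10; with save_checkpoint_steps=4 they land exactly at 4, 8 and 12.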
property async_save

Get the value of _async_save.


Get the policy of checkpoint.

property integrated_save

Get the value of _integrated_save.

property keep_checkpoint_max

Get the value of _keep_checkpoint_max.

property keep_checkpoint_per_n_minutes

Get the value of _keep_checkpoint_per_n_minutes.

property save_checkpoint_seconds

Get the value of _save_checkpoint_seconds.

property save_checkpoint_steps

Get the value of _save_checkpoint_steps.

class mindspore.train.callback.LossMonitor(per_print_times=1)

Monitor the loss in training.

If the loss is NaN or INF, training is terminated.


If per_print_times is 0, the loss is not printed.


per_print_times (int) – Print the loss once every per_print_times steps. Default: 1.


ValueError – If per_print_times is not an integer or is less than zero.
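LossMonitor's two rules (printing every per_print_times steps, with 0 disabling printing, and stopping on a NaN or infinite loss) can be modeled as two small checks. This is a sketch under stated assumptions: should_print and check_loss are hypothetical names, and stopping is modeled by raising ValueError.

```python
import math


def should_print(step: int, per_print_times: int) -> bool:
    """Hypothetical check: print on every per_print_times-th step; 0 disables printing."""
    if not isinstance(per_print_times, int) or per_print_times < 0:
        raise ValueError("per_print_times must be an int >= 0")
    return per_print_times != 0 and step % per_print_times == 0


def check_loss(loss: float) -> None:
    """Hypothetical check: a NaN or infinite loss stops training by raising an error."""
    if math.isnan(loss) or math.isinf(loss):
        raise ValueError(f"invalid loss value: {loss}")
```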

class mindspore.train.callback.ModelCheckpoint(prefix='CKP', directory=None, config=None)

The checkpoint callback class.

It is used together with the training process to save the model and network parameters after training.

  • prefix (str) – The prefix name of checkpoint files. Default: “CKP”.

  • directory (str) – The path of the folder in which the checkpoint files will be saved. Default: None.

  • config (CheckpointConfig) – Checkpoint strategy configuration. Default: None.

  • ValueError – If the prefix is invalid.

  • TypeError – If the config is not CheckpointConfig type.


Save the last checkpoint after training finished.


run_context (RunContext) – Context of the train running.

property latest_ckpt_file_name

Return the latest checkpoint path and file name.


Save the checkpoint at the end of step.


run_context (RunContext) – Context of the train running.

class mindspore.train.callback.RunContext(original_args)

Provides information about the model.

Provides information about the original request to the model function. Callback objects can stop the loop by calling request_stop() of run_context.


original_args (dict) – Holds the related information of the model.


Returns whether a stop is requested or not.


bool, if true, model.train() stops iterations.


Get the _original_args object.


Dict, an object that holds the original arguments of model.


Sets a stop request during training.

Callbacks can use this function to request that the iterations stop. model.train() checks whether this has been called.
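The stop protocol can be modeled without MindSpore: a callback calls request_stop(), and the training loop checks the flag after the callbacks have run. RunContextSketch and train_until_threshold are hypothetical; only request_stop() and original_args() are named in the text above, and the get_stop_requested method name is an assumption made for this sketch.

```python
from typing import Any, Dict, Sequence


class RunContextSketch:
    """Hypothetical stand-in for RunContext: holds original_args and a stop flag."""

    def __init__(self, original_args: Dict[str, Any]):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self) -> Dict[str, Any]:
        return self._original_args

    def request_stop(self) -> None:
        self._stop_requested = True

    def get_stop_requested(self) -> bool:
        return self._stop_requested


def train_until_threshold(losses: Sequence[float], threshold: float) -> int:
    """Toy loop that stops once a 'callback' sees loss < threshold; returns steps run."""
    run_context = RunContextSketch({"cur_step_num": 0})
    for step, loss in enumerate(losses, start=1):
        run_context.original_args()["cur_step_num"] = step
        # What a callback's step_end might do: request a stop once the loss is low enough.
        if loss < threshold:
            run_context.request_stop()
        # The training loop checks the flag after the callbacks have run.
        if run_context.get_stop_requested():
            break
    return run_context.original_args()["cur_step_num"]
```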

class mindspore.train.callback.SummaryCollector(summary_dir, collect_freq=10, collect_specified_data=None, keep_default_action=True, custom_lineage_data=None, collect_tensor_freq=None, max_file_size=None)

SummaryCollector can help you to collect some common information.

It can help you to collect the loss, learning rate, computational graph and so on. SummaryCollector also enables summary operators to write the data they collect to summary files.


  1. Multiple SummaryCollector instances in callback list are not allowed.

  2. Not all information is collected at the training phase or at the eval phase.

  3. SummaryCollector always records the data collected by the summary operator.

  4. SummaryCollector only supports Linux systems.

  • summary_dir (str) – The collected data will be persisted to this directory. If the directory does not exist, it will be created automatically.

  • collect_freq (int) – Set the frequency of data collection, in steps; it should be greater than zero. Default: 10. If a frequency is set, data is collected when (current steps % freq) equals 0, and the first step is always collected. Note that if data sink mode is used, the unit becomes the epoch. Collecting data too frequently is not recommended, as it can affect performance.

  • collect_specified_data (Union[None, dict]) –

    Perform custom operations on the collected data. Default: None. By default, if set to None, all data is collected as the default behavior. You can customize the collected data with a dictionary. For example, you can set {‘collect_metric’: False} to control not collecting metrics. The data that supports control is shown below.

    • collect_metric: Whether to collect training metrics, currently only the loss is collected. The first output will be treated as the loss and it will be averaged. Optional: True/False. Default: True.

    • collect_graph: Whether to collect the computational graph. Currently, only training computational graph is collected. Optional: True/False. Default: True.

    • collect_train_lineage: Whether to collect lineage data for the training phase; this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.

    • collect_eval_lineage: Whether to collect lineage data for the evaluation phase; this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.

    • collect_input_data: Whether to collect dataset for each training. Currently only image data is supported. Optional: True/False. Default: True.

    • collect_dataset_graph: Whether to collect dataset graph for the training phase. Optional: True/False. Default: True.

    • histogram_regular: Collect weights and biases for the parameter distribution page displayed in MindInsight. This field accepts a regular expression string to control which parameters are collected. Default: None, which means only the first five parameters are collected. Collecting too many parameters at once is not recommended, as it can affect performance. Note that if you collect too many parameters and run out of memory, the training will fail.

  • keep_default_action (bool) – This field affects the collection behavior of the ‘collect_specified_data’ field. Optional: True/False, Default: True. True: it means that after specified data is set, non-specified data is collected as the default behavior. False: it means that after specified data is set, only the specified data is collected, and the others are not collected.

  • custom_lineage_data (Union[dict, None]) – Allows you to customize the data and present it on the MindInsight lineage page. In the custom data, the key type supports str, and the value type supports str, int and float. Default: None, which means there is no custom data.

  • collect_tensor_freq (Optional[int]) – The same semantics as collect_freq, but controls TensorSummary only. Because TensorSummary data is much larger than other summary data, this parameter is used to reduce its collection. By default, the maximum number of steps for collecting TensorSummary data is 20, but it will not exceed the number of steps for collecting other summary data. Default: None, which means to follow the behavior described above. For example, given collect_freq=10, when the total number of steps is 600, TensorSummary will be collected for 20 steps, while other summary data will be collected for 61 steps; but when the total number of steps is 20, both TensorSummary and other summary data will be collected for 3 steps. Also note that in parallel mode, the total steps will be split evenly, which affects the number of steps for which TensorSummary is collected.

  • max_file_size (Optional[int]) – The maximum size in bytes of each file that can be written to the disk. Default: None, which means no limit. For example, to write not larger than 4GB, specify max_file_size=4 * 1024**3.

  • ValueError – If the parameter value is not expected.

  • TypeError – If the parameter type is not expected.

  • RuntimeError – If an error occurs during data collection.
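The collect_freq rule (always collect the first step, then every step whose number is divisible by collect_freq) reproduces the 61-step and 3-step figures quoted for collect_tensor_freq. collected_steps is a hypothetical helper covering non-tensor summary data in non-sink mode only; it does not model the TensorSummary cap.

```python
from typing import List


def collected_steps(total_steps: int, collect_freq: int = 10) -> List[int]:
    """Hypothetical helper: steps at which non-tensor summary data would be collected."""
    if not isinstance(collect_freq, int) or collect_freq <= 0:
        raise ValueError("collect_freq must be a positive integer")
    # The first step is always collected; after that, every step divisible by collect_freq.
    return [step for step in range(1, total_steps + 1) if step == 1 or step % collect_freq == 0]
```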


>>> # Simple usage:
>>> summary_collector = SummaryCollector(summary_dir='./summary_dir')
>>> model.train(epoch, dataset, callbacks=summary_collector)
>>> # Do not collect metric and collect the first layer parameter, others are collected by default
>>> specified={'collect_metric': False, 'histogram_regular': '^conv1.*'}
>>> summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
>>> model.train(epoch, dataset, callbacks=summary_collector)
>>> # Only collect metric, custom lineage data and record data collected by the summary operator,
>>> # others are not collected
>>> specified = {'collect_metric': True}
>>> summary_collector = SummaryCollector('./summary_dir',
...                                      collect_specified_data=specified,
...                                      keep_default_action=False,
...                                      custom_lineage_data={'version': 'resnet50_v1'}
...                                      )
>>> model.train(epoch, dataset, callbacks=summary_collector)
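histogram_regular selects parameters by name with a regular expression. The sketch assumes the pattern is applied with Python's re.match and that the default keeps the first five parameter names; select_histogram_params and the parameter names used with it are hypothetical. With a pattern anchored by ^, as in '^conv1.*' above, re.match and re.search select the same names.

```python
import re
from typing import List, Optional, Sequence


def select_histogram_params(names: Sequence[str], pattern: Optional[str] = None,
                            default_count: int = 5) -> List[str]:
    """Hypothetical helper: choose which parameter names get histogram summaries."""
    if pattern is None:
        # Default behavior described above: only the first five parameters.
        return list(names[:default_count])
    # Assumption: the pattern is matched from the start of each parameter name.
    return [name for name in names if re.match(pattern, name)]
```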
class mindspore.train.callback.TimeMonitor(data_size=None)

Monitor the time in training.


data_size (int) – Dataset size. Default: None.


Model and parameters serialization.


Build strategy of every parameter in network.


strategy_filename (str) – Name of strategy file.


Dictionary, whose key is parameter name and value is slice strategy of this parameter.



>>> strategy_filename = "./strategy_train.ckpt"
>>> strategy = build_searched_strategy(strategy_filename)
mindspore.train.serialization.export(net, *inputs, file_name, file_format='AIR')

Export the MindSpore prediction model to a file in the specified format.

  • net (Cell) – MindSpore network.

  • inputs (Tensor) – Inputs of the net.

  • file_name (str) – File name of the model to be exported.

  • file_format (str) –

    MindSpore currently supports ‘AIR’, ‘ONNX’ and ‘MINDIR’ format for exported model.

    • AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model. Recommended suffix for output file is ‘.air’.

    • ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. Recommended suffix for output file is ‘.onnx’.

    • MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format for MindSpore models. Recommended suffix for output file is ‘.mindir’.
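A hypothetical helper that pairs each file_format with the recommended suffix listed above. It is only a convenience sketch; whether export itself appends a suffix is not stated here, so the helper adds it explicitly.

```python
# Recommended output suffix per export format, as listed above.
RECOMMENDED_SUFFIX = {"AIR": ".air", "ONNX": ".onnx", "MINDIR": ".mindir"}


def with_recommended_suffix(file_name: str, file_format: str = "AIR") -> str:
    """Hypothetical helper: append the recommended suffix unless it is already present."""
    if file_format not in RECOMMENDED_SUFFIX:
        raise ValueError(f"unsupported file_format: {file_format!r}")
    suffix = RECOMMENDED_SUFFIX[file_format]
    return file_name if file_name.endswith(suffix) else file_name + suffix
```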

mindspore.train.serialization.load_checkpoint(ckpt_file_name, net=None)

Loads checkpoint info from a specified file.

  • ckpt_file_name (str) – Checkpoint file name.

  • net (Cell) – Cell network. Default: None


Dict, key is parameter name, value is a Parameter.


ValueError – Checkpoint file is incorrect.

mindspore.train.serialization.load_param_into_net(net, parameter_dict)

Loads parameters into network.

  • net (Cell) – Cell network.

  • parameter_dict (dict) – Parameter dictionary.


TypeError – Argument is not a Cell, or parameter_dict is not a Parameter dictionary.

mindspore.train.serialization.merge_sliced_parameter(sliced_parameters, strategy=None)

Merge parameter slices to one whole parameter.

  • sliced_parameters (list[Parameter]) – Parameter slices in order of rank_id.

  • strategy (dict) –

    Parameter slice strategy. Default: None. If strategy is None, the parameter slices are simply merged along axis 0, in order.

    • key (str): Parameter name.

    • value (<class ‘node_strategy_pb2.ParallelLayouts’>): Slice strategy of this parameter.


Parameter, the merged parameter which has the whole data.

  • ValueError – Failed to merge.

  • TypeError – The sliced_parameters is incorrect or strategy is not dict.

  • KeyError – The parameter name is not in keys of strategy.


>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> sliced_parameters = [
...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
...                                "network.embedding_table"),
...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
...                                "network.embedding_table"),
...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
...                                "network.embedding_table"),
...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
...                                "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
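When strategy is None, merging reduces to concatenating the slices along axis 0 in rank_id order. The sketch below shows only that case, with nested Python lists standing in for tensors; merge_slices_axis0 is hypothetical and is not merge_sliced_parameter, which also handles strategy-driven layouts. Under this rule, four 3-element slices would give one 12-element parameter.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def merge_slices_axis0(slices: Sequence[Sequence[T]]) -> List[T]:
    """Hypothetical sketch: concatenate per-rank slices along axis 0."""
    merged: List[T] = []
    # The slices must already be ordered by rank_id.
    for rank_slice in slices:
        merged.extend(rank_slice)
    return merged
```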

Loads Print data from a specified file.


print_file_name (str) – The file name of saved print data.


List, element of list is Tensor.


ValueError – The print file may be empty; please make sure you entered the correct file name.

mindspore.train.serialization.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False)

Saves checkpoint info to a specified file.

  • save_obj (nn.Cell or list) – The cell object or a data list (each element is a dictionary, like [{“name”: param_name, “data”: param_data},…]; param_name is a string, and param_data is a Parameter or Tensor).

  • ckpt_file_name (str) – Checkpoint file name. If the file name already exists, it will be overwritten.

  • integrated_save (bool) – Whether to perform integrated saving in the automatic model parallel scenario. Default: True

  • async_save (bool) – Whether to save the checkpoint to a file asynchronously. Default: False


TypeError – If save_obj is not of type nn.Cell or list, or if integrated_save or async_save is not of type bool.
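The list form of save_obj and the documented bool checks can be sketched as follows. build_save_list and check_save_flags are hypothetical helpers; plain lists stand in for the Parameter or Tensor values that save_checkpoint expects.

```python
from typing import Any, Dict, List


def build_save_list(named_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Hypothetical helper producing the list-of-dicts form accepted as save_obj."""
    return [{"name": name, "data": data} for name, data in named_data.items()]


def check_save_flags(integrated_save: Any, async_save: Any) -> None:
    """Hypothetical check mirroring the documented TypeError for non-bool flags."""
    for flag_name, flag in (("integrated_save", integrated_save), ("async_save", async_save)):
        if not isinstance(flag, bool):
            raise TypeError(f"{flag_name} must be bool, got {type(flag).__name__}")
```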


Auto mixed precision.

mindspore.train.amp.build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs)

Build the mixed precision training cell automatically.

  • network (Cell) – Definition of the network.

  • loss_fn (Union[None, Cell]) – Definition of the loss_fn. If None, the network should have the loss inside. Default: None.

  • optimizer (Optimizer) – Optimizer to update the Parameter.

  • level (str) –

    Supports [“O0”, “O2”, “O3”, “auto”]. Default: “O0”.

    • O0: Do not change.

    • O2: Cast the network to float16, keep BatchNorm and loss_fn (if set) running in float32, and use dynamic loss scaling.

    • O3: Cast network to float16, with additional property ‘keep_batchnorm_fp32=False’.

    • auto: Set the level to the recommended level for the current device: O2 on GPU and O3 on Ascend. The recommended levels are based on expert experience and may not generalize to every network; users should specify the level explicitly for special networks.

    O2 is recommended on GPU, O3 is recommended on Ascend.

  • cast_model_type (mindspore.dtype) – Supports mstype.float16 or mstype.float32. If set to mstype.float16, use float16 mode to train. If set, overwrite the level setting.

  • keep_batchnorm_fp32 (bool) – Keep BatchNorm running in float32. If set, it overrides the level setting. keep_batchnorm_fp32 takes effect only when cast_model_type is float16.

  • loss_scale_manager (Union[None, LossScaleManager]) – If None, the loss is not scaled; otherwise, the loss is scaled by the LossScaleManager. If set, it overrides the level setting.
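How level, the device and the keyword overrides combine can be sketched as a small resolver. It is a hedged model of the rules listed above, not build_train_network itself: resolve_amp_config is hypothetical, strings stand in for mstype values, loss scaling is left out, and raising for 'auto' on devices other than GPU and Ascend is an assumption of the sketch.

```python
from typing import Any, Dict

# Per-level defaults as described for O0/O2/O3; strings stand in for mstype values.
_LEVEL_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "O0": {"cast_model_type": None, "keep_batchnorm_fp32": None},  # do not change the network
    "O2": {"cast_model_type": "float16", "keep_batchnorm_fp32": True},
    "O3": {"cast_model_type": "float16", "keep_batchnorm_fp32": False},
}


def resolve_amp_config(level: str, device_target: str, **overrides: Any) -> Dict[str, Any]:
    """Hypothetical sketch of how level, device and keyword overrides combine."""
    if level == "auto":
        # Recommended levels: O2 on GPU, O3 on Ascend.
        if device_target == "GPU":
            level = "O2"
        elif device_target == "Ascend":
            level = "O3"
        else:
            raise ValueError("this sketch only resolves 'auto' for GPU and Ascend")
    if level not in _LEVEL_DEFAULTS:
        raise ValueError(f"unsupported level: {level!r}")
    config: Dict[str, Any] = dict(_LEVEL_DEFAULTS[level], level=level)
    # Explicitly passed keyword arguments override the level's defaults.
    for key in ("cast_model_type", "keep_batchnorm_fp32"):
        if key in overrides:
            config[key] = overrides[key]
    # keep_batchnorm_fp32 only takes effect when the network is cast to float16.
    if config["cast_model_type"] != "float16":
        config["keep_batchnorm_fp32"] = None
    return config
```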


Loss scale manager abstract class.

class mindspore.train.loss_scale_manager.DynamicLossScaleManager(init_loss_scale=16777216, scale_factor=2, scale_window=2000)

Dynamic loss-scale manager.

  • init_loss_scale (float) – Initial loss scale. Default: 2**24.

  • scale_factor (int) – Factor by which the loss scale is increased or decreased. Default: 2.

  • scale_window (int) – Number of consecutive overflow-free steps after which the loss scale is increased. Default: 2000.


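>>> from mindspore.train.model import Model
>>> from mindspore.train.loss_scale_manager import DynamicLossScaleManager
>>> loss_scale_manager = DynamicLossScaleManager()
>>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale_manager)  # net, loss and optim defined beforehand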

Get the flag indicating whether to drop the optimizer update when there is an overflow.


Get loss scale value.


Returns the cell for TrainOneStepWithLossScaleCell.


Update loss scale value.


overflow (bool) – Whether it overflows.
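The update rule suggested by init_loss_scale, scale_factor and scale_window can be modeled in plain Python. DynamicLossScaleSketch is a hypothetical class written for illustration, not MindSpore's implementation; the floor of 1 on the scale and restarting the window after an overflow are assumptions of this sketch.

```python
class DynamicLossScaleSketch:
    """Hypothetical pure-Python model of dynamic loss scaling (not MindSpore code)."""

    def __init__(self, init_loss_scale=2 ** 24, scale_factor=2, scale_window=2000):
        self.loss_scale = float(init_loss_scale)
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def get_loss_scale(self):
        return self.loss_scale

    def update_loss_scale(self, overflow):
        if overflow:
            # Shrink the scale (assumed never below 1) and restart the window.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.scale_window:
                # A full window without overflow: try a larger scale.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0
```

With init_loss_scale=1024 and scale_window=3, one overflowing step halves the scale to 512, and three further clean steps restore it to 1024. Growing the scale only after a full clean window keeps it from oscillating on every step.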

class mindspore.train.loss_scale_manager.FixedLossScaleManager(loss_scale=128.0, drop_overflow_update=True)[source]

Fixed loss-scale manager.

  • loss_scale (float) – Loss scale. Default: 128.0.

  • drop_overflow_update (bool) – Whether to skip the optimizer update when there is an overflow. If True, the optimizer is not executed on a step that overflows. Default: True.


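>>> from mindspore.train.model import Model
>>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
>>> loss_scale_manager = FixedLossScaleManager()
>>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale_manager)  # net, loss and optim defined beforehand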

Get the flag indicating whether to drop the optimizer update when there is an overflow.


Get loss scale value.


Returns the cell for TrainOneStepWithLossScaleCell.


Update loss scale value.


overflow (bool) – Whether it overflows.
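A fixed loss scale multiplies the loss by a constant, so the gradients must be divided by the same constant before the weights are updated. The sketch below shows where drop_overflow_update comes in, using a plain SGD step on Python floats; sgd_step_with_fixed_loss_scale and its lr argument are hypothetical and do not reproduce MindSpore's TrainOneStepWithLossScaleCell.

```python
import math


def sgd_step_with_fixed_loss_scale(params, scaled_grads, loss_scale=128.0,
                                   drop_overflow_update=True, lr=0.1):
    """One hypothetical SGD step where gradients come from (loss * loss_scale).

    Returns (new_params, overflow).
    """
    # Undo the constant scaling so the update uses the true gradients.
    grads = [g / loss_scale for g in scaled_grads]
    overflow = not all(math.isfinite(g) for g in grads)
    if overflow and drop_overflow_update:
        # Skip the optimizer for this step and keep the old weights.
        return list(params), True
    # Otherwise apply the update, even if it contains inf/nan values.
    return [p - lr * g for p, g in zip(params, grads)], overflow
```

With drop_overflow_update=True, an overflowing step leaves the weights untouched; with False, non-finite gradients reach the weights, which is why dropping the update is the default.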

class mindspore.train.loss_scale_manager.LossScaleManager[source]

Loss scale manager abstract class.


Get loss scale value.


Get the loss scaling update logic cell.


Update loss scale value.


overflow (bool) – Whether it overflows.
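The abstract class above defines a small contract: return the current scale, update it after each step, and provide the update logic used during training. A minimal sketch of that contract using Python's abc module follows. The class names and method names are illustrative assumptions chosen to match the descriptions above, and returning None from the update method is a simplification; this is not MindSpore's class.

```python
from abc import ABC, abstractmethod


class LossScaleManagerSketch(ABC):
    """Hypothetical stand-in for an abstract loss scale manager."""

    @abstractmethod
    def get_loss_scale(self):
        """Return the current loss scale value."""

    @abstractmethod
    def update_loss_scale(self, overflow):
        """Update the loss scale given whether the last step overflowed."""

    @abstractmethod
    def get_update_cell(self):
        """Return the object that performs the update during training."""


class ConstantLossScale(LossScaleManagerSketch):
    """Simplest concrete manager: the scale never changes."""

    def __init__(self, loss_scale=128.0):
        self._loss_scale = float(loss_scale)

    def get_loss_scale(self):
        return self._loss_scale

    def update_loss_scale(self, overflow):
        pass  # a constant scale ignores overflow

    def get_update_cell(self):
        return None  # this sketch needs no in-graph update logic
```

Because every method is declared abstract, the base class cannot be instantiated directly; a subclass that forgets one of the three methods fails at construction time rather than mid-training.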



Users can train a model with quantization aware training. MindSpore supports quantization aware training, which models quantization errors in both the forward and backward passes using fake-quantization operations. Note that the entire computation is carried out in floating point. At the end of quantization aware training, MindSpore provides conversion functions to convert the trained model to lower precision.
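Fake quantization can be illustrated with a minimal per-tensor, asymmetric sketch in plain Python: a value is snapped to one of 2**num_bits integer levels and immediately mapped back to a float. The function name and the use of Python's round (which rounds halves to even) are illustrative assumptions; the range [x_min, x_max] is a fixed input here, whereas a training framework typically tracks it during training.

```python
def fake_quantize(x, x_min, x_max, num_bits=8):
    """Quantize x onto an unsigned num_bits grid over [x_min, x_max], then dequantize.

    The result is still a Python float, so downstream floating-point
    computation "sees" the rounding error that real quantization would add.
    """
    quant_min, quant_max = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (quant_max - quant_min)
    zero_point = round(quant_min - x_min / scale)
    q = round(x / scale) + zero_point
    q = min(max(q, quant_min), quant_max)  # clamp to the integer range
    return (q - zero_point) * scale
```

With the range [0.0, 2.55] and 8 bits, the step size is 0.01, so fake_quantize(0.123, 0.0, 2.55) returns approximately 0.12, and any in-range input moves by at most half a step. Inputs outside the range are clamped, for example 5.0 becomes approximately 2.55.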

mindspore.train.quant.convert_quant_network(network, bn_fold=True, freeze_bn=10000000, quant_delay=(0, 0), num_bits=(8, 8), per_channel=(False, False), symmetric=(False, False), narrow_range=(False, False))[source]

Create quantization aware training network.

  • network (Cell) – The network to be converted into a quantization aware training network.

  • bn_fold (bool) – Whether to use BatchNorm-folding ops to simulate the inference operation. Default: True.

  • freeze_bn (int) – Number of steps after which the BatchNorm op parameters use the total mean and variance. Default: 1e7.

  • quant_delay (int, list or tuple) – Number of steps after which weights and activations are quantized during eval. The first element represents weights and the second element represents data flow. Default: (0, 0)

  • num_bits (int, list or tuple) – Number of bits used to quantize weights and activations. The first element represents weights and the second element represents data flow. Default: (8, 8)

  • per_channel (bool, list or tuple) – Quantization granularity, per layer or per channel. If True, quantization is per channel; otherwise, it is per layer. The first element represents weights and the second element represents data flow. Default: (False, False)

  • symmetric (bool, list or tuple) – Whether the quantization algorithm is symmetric. If True, symmetric quantization is used; otherwise, asymmetric quantization is used. The first element represents weights and the second element represents data flow. Default: (False, False)

  • narrow_range (bool, list or tuple) – Whether the quantization algorithm uses narrow range or not. The first element represents weights and the second element represents data flow. Default: (False, False)


Cell, the network converted into a quantization aware training network cell.
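The num_bits, symmetric, narrow_range and per_channel options all change how a range of real values is mapped onto integers. The sketch below computes a (scale, zero_point) pair under one common convention for an unsigned grid. The function names, widening the range to include 0, and treating each weight row as one channel are assumptions; this is not MindSpore's internal code.

```python
def quant_params(x_min, x_max, num_bits=8, symmetric=False, narrow_range=False):
    """Return (scale, zero_point) for an unsigned num_bits integer grid."""
    # Assumption: widen the real range so that 0.0 is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if symmetric:
        # Symmetric: make the range balanced around zero.
        bound = max(abs(x_min), abs(x_max))
        x_min, x_max = -bound, bound
    # narrow_range drops the lowest integer level, e.g. [1, 255] instead of [0, 255].
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    if x_max == x_min:  # all-zero range: any positive scale works
        return 1.0, quant_min
    scale = (x_max - x_min) / (quant_max - quant_min)
    zero_point = round(quant_min - x_min / scale)
    return scale, zero_point


def weight_quant_params(weight_rows, per_channel=False, **kwargs):
    """Per layer: one (scale, zero_point) for the whole tensor.
    Per channel: one pair per output channel (here, per row)."""
    if per_channel:
        return [quant_params(min(row), max(row), **kwargs) for row in weight_rows]
    flat = [v for row in weight_rows for v in row]
    return [quant_params(min(flat), max(flat), **kwargs)]
```

Per-channel quantization gives a channel with small weights its own, finer step size instead of sharing the coarse step forced on the whole layer by its largest channel.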

mindspore.train.quant.export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='AIR')[source]

Exports a MindSpore quantization inference model for deployment in AIR or MINDIR format.

  • network (Cell) – MindSpore network produced by convert_quant_network.

  • inputs (Tensor) – Inputs of the quantization aware training network.

  • file_name (str) – File name of model to export.

  • mean (int, float) – Input data mean. Default: 127.5.

  • std_dev (int, float) – Input data standard deviation. Default: 127.5.

  • file_format (str) –

    MindSpore currently supports the ‘AIR’ and ‘MINDIR’ formats for the exported quantization aware model. Default: ‘AIR’.

    • AIR: Graph Engine Intermediate Representation. An intermediate representation format for Ascend models.

    • MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format for MindSpore models. The recommended suffix for the output file is ‘.mindir’.
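The mean and std_dev arguments describe how raw input data relate to the float inputs the network saw during training. The plain-Python sketch below assumes the common convention float_input = (raw - mean) / std_dev. Under that convention, the defaults map 8-bit pixel values 0 to 255 onto [-1, 1]. The helper names are hypothetical, and the exact way export consumes these two numbers is not shown here.

```python
def normalize_input(raw_values, mean=127.5, std_dev=127.5):
    """Map raw values (e.g. uint8 pixels) to the float range assumed in training."""
    return [(v - mean) / std_dev for v in raw_values]


def denormalize_input(float_values, mean=127.5, std_dev=127.5):
    """Inverse mapping: recover raw values from normalized floats."""
    return [x * std_dev + mean for x in float_values]
```

If the training pipeline normalized images differently, for example to [0, 1], mean and std_dev would need to change to match (mean=0, std_dev=255 under this convention).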