Source code for mindspore.train.model

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
import numpy as np

from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback import _InternalCallbackParam, RunContext, _build_callbacks
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
from ..nn.metrics import Loss
from ..nn.wrap import WithLossCell, WithEvalCell, \
    DataWrapper
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp


[docs]class Model:
    """
    High-Level API for Training or Testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): The training or testing network.
        loss_fn (Cell): Objective function, if loss_fn is None, the network should contain the logic of loss
                        and grads calculation, and the logic of parallel if needed. Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
                                    training and testing. eg: {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
                             `eval_network`. Default: None.
        eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of
                             `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
                             elements, representing the positions of loss value, predict value and label. The loss
                             value would be passed to the `Loss` metric, and the predict value and label would be
                             passed to other metrics. Default: None.
        amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
            precision training. Supports [O0, O2]. Default: "O0".

            - O0: Do not change.
            - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.

        loss_scale_manager (Union[None, LossScaleManager]): If None, do not scale the loss; otherwise
            scale the loss by LossScaleManager. If set, it overwrites the level setting. It's a keyword
            argument, e.g. use `loss_scale_manager=None` to set the value.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*222*222, 3)  # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._check_kwargs(kwargs)
        if 'loss_scale_manager' in kwargs:
            self._loss_scale_manager = kwargs['loss_scale_manager']
            self._loss_scale_manager_set = True
        self._amp_level = amp_level
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager']:
                raise ValueError(f"Unsupported arg '{arg}'")

    def _build_train_network(self):
        """Build train network."""
        network = self._network
        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  loss_scale_manager=self._loss_scale_manager)
            else:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level)
        elif self._loss_fn:
            network = WithLossCell(network, self._loss_fn)
        # may need to check the case where loss_fn is not None but optimizer is None
        return network

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, "
                                 "its length must be three. But got {}".format(eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            self._eval_network = WithEvalCell(self._network, self._loss_fn)
            self._eval_indexes = [0, 1, 2]

    def _clear_metrics(self):
        """Clear metrics local values."""
        for metric in self._metric_fns.values():
            metric.clear()

    def _update_metrics(self, outputs):
        """Update metrics local values."""
        if self._eval_indexes is not None and len(outputs) < 3:
            raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                             "but got {}".format(len(outputs)))

        for metric in self._metric_fns.values():
            if self._eval_indexes is None:
                metric.update(*outputs)
            else:
                if isinstance(metric, Loss):
                    metric.update(outputs[self._eval_indexes[0]])
                else:
                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])

    def _get_metrics(self):
        """Get metrics local values."""
        metrics = dict()
        for key, value in self._metric_fns.items():
            metrics[key] = value.eval()
        return metrics

    def _get_scaling_sens(self):
        """Get the scaling sens."""
        scaling_sens = 1
        if self._loss_scale_manager is not None:
            scaling_sens = self._loss_scale_manager.get_loss_scale()
        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
            scaling_sens /= self._device_number
        return scaling_sens

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
                                     returned and passed to the network. Otherwise, a tuple (data, label) will
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
""" epoch = check_int_positive(epoch) self._train_network.set_train() if self._parameter_broadcast: self._train_network.set_broadcast_flag() # build callback list list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.train_network = self._train_network cb_params.epoch_num = epoch cb_params.batch_num = train_dataset.get_dataset_size() cb_params.mode = "train" cb_params.loss_fn = self._loss_fn cb_params.optimizer = self._optimizer cb_params.parallel_mode = self._parallel_mode cb_params.device_number = self._device_number cb_params.train_dataset = train_dataset cb_params.list_callback = list_callback if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE: self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) else: self._train_process(epoch, train_dataset, list_callback, cb_params) def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ Training process. The data would be passed to network through dataset channel. Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. list_callback (_ListCallback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ # remove later to deal with loop sink need_wrap = False if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"): need_wrap = True dataset_helper = DatasetHelper(train_dataset) # remove later to deal with loop sink if need_wrap: self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()), train_dataset.__ME_INITED__) cb_params.train_network = self._train_network self._train_network.set_train() cb_params.cur_step_num = 0 loop_size = dataset_helper.loop_size() run_context = RunContext(cb_params) _callback_wrapper(list_callback, run_context, "begin") # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False for i in range(epoch): cb_params.cur_epoch_num = i + 1 _callback_wrapper(list_callback, run_context, "epoch_begin") # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: cb_params.cur_step_num += loop_size _callback_wrapper(list_callback, run_context, "step_begin") outputs = self._train_network(*inputs) cb_params.net_outputs = outputs _callback_wrapper(list_callback, run_context, "step_end") _callback_wrapper(list_callback, run_context, "epoch_end") should_stop = should_stop or run_context.get_stop_requested() if should_stop: break _callback_wrapper(list_callback, run_context, "end") def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ Training process. The data would be passed to network directly. Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. list_callback (_ListCallback): Executor of callback list. Default: None. 
cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) cb_params.cur_step_num = 0 run_context = RunContext(cb_params) _callback_wrapper(list_callback, run_context, "begin") # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False for i in range(epoch): cb_params.cur_epoch_num = i + 1 _callback_wrapper(list_callback, run_context, "epoch_begin") for next_element in dataset_helper: len_element = len(next_element) if self._loss_fn and len_element != 2: raise ValueError("when loss_fn is not None, train_dataset should" "return two elements, but got {}".format(len_element)) cb_params.cur_step_num += 1 _callback_wrapper(list_callback, run_context, "step_begin") overflow = False if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): scaling_sens = self._get_scaling_sens() next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) outputs = self._train_network(*next_element) cb_params.net_outputs = outputs if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): _, overflow, _ = outputs overflow = np.all(overflow.asnumpy()) self._loss_scale_manager.update_loss_scale(overflow) _callback_wrapper(list_callback, run_context, "step_end") should_stop = should_stop or run_context.get_stop_requested() if should_stop: break train_dataset.reset() _callback_wrapper(list_callback, run_context, "epoch_end") should_stop = should_stop or run_context.get_stop_requested() if should_stop: break _callback_wrapper(list_callback, run_context, "end")
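    # A minimal usage sketch of the loss-scaling path above, assuming `net`, `loss`, `optim`
    # and `dataset` as in the class example. When the manager's get_drop_overflow_update()
    # returns True, _train_process appends the current scaling sens as an extra float32
    # Tensor input, expects the train network to return (loss, overflow, sens), and feeds
    # the overflow flag back through update_loss_scale(). Note the arithmetic in
    # _get_scaling_sens: in DATA_PARALLEL mode the loss scale is divided by the device
    # number, e.g. a scale of 1024 on 8 devices yields a sens of 128.
    #
    # >>> from mindspore.train.loss_scale_manager import DynamicLossScaleManager
    # >>> manager = DynamicLossScaleManager()
    # >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=manager)
    # >>> model.train(2, dataset, dataset_sink_mode=False)  # non-sink path runs the overflow check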
[docs]    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training API where the iteration is controlled by python front-end.

        When configured to PyNative mode, the training will be performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is True.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
        repeat_count = train_dataset.get_repeat_count()
        if epoch != repeat_count:
            logger.warning(f"The epoch_size {epoch} is not the same as the dataset repeat_count {repeat_count}")
        check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
            raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.")

        self._train(epoch, train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode)
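    # A hedged sketch of the repeat_count check in train() above, assuming `model` and
    # `dataset` as in the examples; the repeat() call is illustrative. train() warns when
    # `epoch` differs from the dataset repeat count, and on CPU/GPU targets loop sink must
    # be disabled first.
    #
    # >>> context.set_context(enable_loop_sink=False)   # required on CPU/GPU
    # >>> dataset = dataset.repeat(2)                   # repeat_count becomes 2
    # >>> model.train(2, dataset)                       # matches repeat_count, no warning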
    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to network through dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        _device_number_check(self._parallel_mode, self._device_number)

        run_context = RunContext(cb_params)

        # remove later to deal with loop sink
        need_wrap = False
        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"):
            need_wrap = True

        valid_dataset.__loop_size__ = 1
        dataset_helper = DatasetHelper(valid_dataset)

        # remove later to deal with loop sink
        if need_wrap:
            self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
                                             valid_dataset.__ME_INITED__)
            self._eval_network.set_train(mode=False)
            self._eval_network.phase = 'eval'

        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
        for next_element in dataset_helper:
            # keep the step counter consistent with the dataset-sink path
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics
[docs]    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by python front-end.

        When configured to PyNative mode, the evaluation will be performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is True.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects which should be executed while evaluating. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
        if not self._metric_fns:
            raise ValueError("metric fn can not be None or empty.")

        list_callback = _build_callbacks(callbacks)
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'

        self._clear_metrics()

        if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE:
            return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)

        return self._eval_process(valid_dataset, list_callback, cb_params)
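    # A minimal sketch of wiring a custom eval_network with eval_indexes, assuming an eval
    # cell that returns (loss, logits, label) in that order; the metric names are
    # illustrative. With eval_indexes=[0, 1, 2], _update_metrics routes outputs[0] to the
    # Loss metric and (outputs[1], outputs[2]) to every other metric.
    #
    # >>> eval_net = nn.WithEvalCell(net, loss)  # returns (loss, logits, label)
    # >>> model = Model(net, loss_fn=loss, metrics={'loss', 'accuracy'},
    # ...               eval_network=eval_net, eval_indexes=[0, 1, 2])
    # >>> model.eval(dataset)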
[docs]    def predict(self, *predict_data):
        """
        Generates output predictions for the input samples.

        Data could be a single tensor, a list of tensor, or a tuple of tensor.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Tensor): Tensor of predict data, can be array, list or tuple.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mstype.float32)
            >>> model = Model(Net())
            >>> model.predict(input_data)
        """
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._network = _VirtualDatasetCell(self._network)

        self._network.set_train(False)
        check_input_data(*predict_data, data_class=Tensor)
        result = self._network(*predict_data)

        check_output_data(result)
        return result
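# A small sketch of the batching note in predict() above; the shapes are illustrative.
# Individual samples should be stacked into one batch tensor rather than passed one by one.
#
# >>> import numpy as np
# >>> samples = [np.random.rand(3, 224, 224).astype(np.float32) for _ in range(4)]
# >>> batch = Tensor(np.stack(samples))  # shape (4, 3, 224, 224)
# >>> model.predict(batch)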
__all__ = ["Model"]