# Source code for mindspore.nn.wrap.cell_wrapper

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
                                       _get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite.base import _mp_cast_helper
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with loss function. This Cell accepts data and label as inputs and
    the computed loss will be returned.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        # auto_prefix=False keeps the wrapped network's parameter names unchanged.
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        # Forward the data through the backbone, then score the output against the label.
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Returns the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone
class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with backward cell to compute gradients. A network with a loss function is necessary
    as argument. If loss function is None, the network must be a wrapper of network and loss function. This
    Cell accepts data and label as inputs and returns gradients for each trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity for backpropagation; its type and shape
            should be the same as the `network` output. If None, we will fill one to a same type shape of
            output value. Default: None.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with identical shapes as trainable weights.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        # An explicit sensitivity input is only passed to the gradient function when `sens` was supplied.
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            # The caller is expected to pass a network that already includes its loss.
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()

    def construct(self, data, label):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(data, label)
        else:
            grads = self.grad(self.network_with_loss, weights)(data, label, self.sens)
        return grads
class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
    Backward graph will be created in the construct function to do parameter updating. Different
    parallel modes are available to run the training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        # Gradients are only aggregated across devices in data-parallel or hybrid-parallel mode.
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, data, label):
        weights = self.weights
        loss = self.network(data, label)
        # Sensitivity tensor with the dtype and shape of `loss`, filled with the scalar `self.sens`.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, label, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        # Return `loss` while making it depend on the optimizer step
        # (presumably so the update is kept in the compiled graph).
        return F.depend(loss, self.optimizer(grads))
class DataWrapper(Cell):
    """
    Network training package class for dataset.

    DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
    function from the dataset channel 'queue_name' and does forward computation in the construct function.

    Args:
        network (Cell): The training network for dataset.
        dataset_types (list): The types of the dataset. The list describes the types of the inputs.
        dataset_shapes (list): The shapes of the dataset. The list contains multiple sublists that describe
            the shapes of the inputs.
        queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
            data for the network.

    Outputs:
        Tensor, network output whose shape depends on the network.

    Examples:
        >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
        >>> train_dataset = create_dataset()
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset)
        >>> net = Net()
        >>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
    """

    def __init__(self, network, dataset_types, dataset_shapes, queue_name):
        # Inherit the wrapped network's flags so this wrapper is compiled with the same settings.
        super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
        # The output count given to GetNext is the number of dataset types.
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
        self.network = network

    def construct(self):
        outputs = self.get_next()
        # Each fetched item becomes one positional input of the wrapped network.
        return self.network(*outputs)
class GetNextSingleOp(Cell):
    """
    Cell to run get next operation.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of dataset.
        dataset_shapes (list[tuple[int]]): The shapes of dataset.
        queue_name (str): Queue name to fetch the data.

    Detailed information, please refer to `ops.operations.GetNext`.
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        # The output count given to GetNext is the number of dataset types.
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()
class _VirtualDatasetCell(Cell):
    """
    Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph. Inputs and outputs
    of _VirtualDataset are distributed in data parallel pattern, and tensor redistribution Primitives are
    inserted dynamically during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, data, label):
        # Route both inputs through the virtual dataset op before handing them to the backbone.
        data_, label_ = self._virtual_dataset(data, label)
        return self._backbone(data_, label_)
class WithEvalCell(Cell):
    r"""
    Cell that returns loss, output and label for evaluation.

    This Cell accepts a network and loss function as arguments and computes loss for model.
    It returns loss, output and label to calculate the metrics.

    Args:
        network (Cell): The network Cell.
        loss_fn (Cell): The loss Cell.
        add_cast_fp32 (bool): If True, the label and the network output are cast to float32
            before the loss is computed. Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
        and a label Tensor of shape :math:`(N, \ldots)`.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = add_cast_fp32

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            # NOTE(review): presumably used for mixed-precision networks so the loss is computed in fp32.
            label = _mp_cast_helper(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label
class ParameterUpdate(Cell):
    """
    Cell that updates parameters.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Raises:
        TypeError: If `param` is not a `Parameter`.

    Examples:
        >>> network = Net()
        >>> param = network.parameters_dict()['learning_rate']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> lr = Tensor(0.001, mindspore.float32)
        >>> update(lr)
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("`param` must be `Parameter`, but got {}".format(param))
        self._param = param

    def construct(self, x):
        # Assign the input into the held parameter and pass the input through unchanged.
        F.assign(self._param, x)
        return x