mindspore.nn.Cell

class mindspore.nn.Cell(auto_prefix=True, flags=None)

The basic building block of neural networks in MindSpore. Models and neural network layers should inherit from this base class.

The layers in mindspore.nn, such as mindspore.nn.Conv2d and mindspore.nn.ReLU, are also subclasses of Cell. A Cell is compiled into a computation graph in GRAPH_MODE (static graph mode) and serves as the basic module of a neural network in PYNATIVE_MODE (dynamic graph mode).

Parameters
  • auto_prefix (bool, optional) – Whether to automatically generate a namespace for the Cell and its child cells. This also affects the names of the parameters in the Cell: if set to True, parameter names are automatically prefixed; otherwise they are not. In general, the backbone network should set it to True, otherwise duplicate parameter names may occur. Cells that train the backbone network, such as optimizers and mindspore.nn.TrainOneStepCell, should set it to False, otherwise the names of the backbone's parameters will be changed by mistake. Default: True.

  • flags (dict, optional) – Network configuration information, currently used to bind the network to a dataset. Users can also customize network attributes with this parameter. Default: None.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> class MyCell(nn.Cell):
...     def __init__(self, forward_net):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.net = forward_net
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = self.net(x)
...         return self.relu(y)
>>>
>>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> my_net = MyCell(inner_net)
>>> # If 'auto_prefix' is set to True (or not set) when calling the parent class's '__init__' method,
>>> # the parameter's name will be 'net.weight'.
>>> print(my_net.trainable_params())
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
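
To make the effect of auto_prefix concrete, the following plain-Python sketch models how a parent's setting decides whether a child's attribute name is prepended to the child's parameter names. ToyCell is a hypothetical stand-in written for illustration only; it is not MindSpore's implementation.

```python
class ToyCell:
    """Hypothetical, simplified stand-in for mindspore.nn.Cell (not the real class)."""

    def __init__(self, auto_prefix=True, params=()):
        self.auto_prefix = auto_prefix
        self.params = list(params)  # names of parameters owned directly by this cell
        self.children = {}          # attribute name -> child ToyCell

    def add_child(self, name, child):
        self.children[name] = child
        return self

    def param_names(self, prefix=""):
        """Full parameter names; a child's attribute name is prepended only if
        this (parent) cell has auto_prefix=True."""
        names = [prefix + p for p in self.params]
        for attr, child in self.children.items():
            child_prefix = prefix + attr + "." if self.auto_prefix else prefix
            names.extend(child.param_names(child_prefix))
        return names


backbone = ToyCell().add_child("conv", ToyCell(params=["weight"]))
print(backbone.param_names())  # ['conv.weight']

# A training wrapper with auto_prefix=False leaves the backbone's names untouched.
print(ToyCell(auto_prefix=False).add_child("network", backbone).param_names())  # ['conv.weight']
print(ToyCell(auto_prefix=True).add_child("network", backbone).param_names())   # ['network.conv.weight']
```

This mirrors the guidance above: the backbone keeps auto_prefix=True so its own child names stay distinct, while a wrapper that trains it uses auto_prefix=False so the backbone's names are not changed.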

add_flags(**flags)

Adds customized attributes to the cell.

This method is also called when the cell is instantiated with a non-empty ‘flags’ argument.

Parameters

flags (dict) – Network configuration information, currently used to bind the network to a dataset. Users can also customize network attributes with this parameter.

add_flags_recursive(**flags)

Adds customized attributes to this cell and, recursively, to all of its child cells.

Parameters

flags (dict) – Network configuration information, currently used to bind the network to a dataset. Users can also customize network attributes with this parameter.

apply(fn)

Applies fn recursively to every subcell (as returned by .cells()) as well as self. Typical use includes initializing the parameters of a model.

Parameters

fn (function) – function to be applied to each subcell.

Returns

Cell, self.

Examples

>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer, One
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
>>> def func(cell):
...     if isinstance(cell, nn.Dense):
...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
>>> net.apply(func)
SequentialCell<
  (0): Dense<input_channels=2, output_channels=2, has_bias=True>
  (1): Dense<input_channels=2, output_channels=2, has_bias=True>
  >
>>> print(net[0].weight.asnumpy())
[[1. 1.]
 [1. 1.]]

auto_cast_inputs(inputs)

Automatically casts the inputs in mixed-precision scenarios.

Parameters

inputs (tuple) – the inputs of construct.

Returns

Tuple, the inputs after data type cast.

property bprop_debug

Gets whether debugging of the cell’s custom bprop is enabled.

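cast_inputs(inputs, dst_type)

Casts the inputs to the specified data type.

Parameters
  • inputs (tuple[Tensor]) – The cell inputs.

  • dst_type (mindspore.dtype) – The specified data type.

Returns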

tuple[Tensor], the result with destination data type.

cast_param(param)

Casts the parameter according to the automatic mixed-precision level in PyNative mode.

This interface is currently used for automatic mixed precision and usually does not need to be called explicitly.

Parameters

param (Parameter) – The parameter whose type should be cast.

Returns

Parameter, the input parameter with type automatically cast.

cells()

Returns an iterator over immediate cells.

Returns

Iteration, the immediate cells in the cell.

cells_and_names(cells=None, name_prefix='')

Returns an iterator over all cells in the network, including the cell itself. Each item is a (name, cell) pair.

Parameters
  • cells (set) – Cells that have already been visited; they are skipped during iteration. Default: None.

  • name_prefix (str) – Namespace. Default: ‘’.

Returns

Iteration, all the child cells and corresponding names in the cell.

Examples

>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(3, 64, 3)
...     def construct(self, x):
...         out = self.conv(x)
...         return out
>>> names = []
>>> n = Net()
>>> for m in n.cells_and_names():
...     if m[0]:
...         names.append(m[0])
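
The first pair that cells_and_names yields is the cell itself under the empty name prefix, which is why the example above skips entries whose name is empty. The plain-Python sketch below models this depth-first traversal with dotted names. ToyCell is hypothetical, and treating the cells argument as a set of already-visited cells (so that a cell shared by two parents is reported once) is an assumption of the sketch.

```python
class ToyCell:
    """Hypothetical, simplified stand-in for mindspore.nn.Cell (not the real class)."""

    def __init__(self, **children):
        self.children = children  # attribute name -> child ToyCell

    def cells_and_names(self, cells=None, name_prefix=""):
        """Yield (name, cell) pairs depth-first, starting with this cell itself."""
        cells = set() if cells is None else cells  # ids of cells already visited
        if id(self) in cells:  # a cell shared by two parents is reported only once
            return
        cells.add(id(self))
        yield name_prefix, self
        for name, child in self.children.items():
            child_prefix = f"{name_prefix}.{name}" if name_prefix else name
            yield from child.cells_and_names(cells, child_prefix)


net = ToyCell(backbone=ToyCell(conv=ToyCell(), relu=ToyCell()), head=ToyCell())
print([name for name, _ in net.cells_and_names()])
# ['', 'backbone', 'backbone.conv', 'backbone.relu', 'head']
```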

check_names()

Check the names of cell parameters.

compile(*args, **kwargs)

Compiles the Cell into a computation graph. The inputs must be consistent with the inputs defined in construct.

Parameters
  • args (tuple) – Args of the Cell object.

  • kwargs (dict) – Kwargs of the Cell object.

compile_and_run(*args, **kwargs)

Compiles and runs the Cell. The inputs must be consistent with the inputs defined in construct.

Note

Calling this method directly is not recommended.

Parameters
  • args (tuple) – Args of the Cell object.

  • kwargs (dict) – Kwargs of the Cell object.

Returns

Object, the result of the execution.

construct(*args, **kwargs)

Defines the computation to be performed. This method must be overridden by all subclasses.

Note

Inputs that contain both tuple and non-tuple types at the same time are currently not supported.

Parameters
  • args (tuple) – Tuple of variable parameters.

  • kwargs (dict) – Dictionary of variable keyword parameters.

Returns

Tensor, returns the computed result.

exec_checkpoint_graph()

Executes the graph operation that saves the checkpoint.

extend_repr()

Extends the description of the Cell.

To print customized extended information, re-implement this method in your own cells.

flatten_weights(fusion_size=0)

Resets the data of weight parameters so that they use contiguous memory chunks grouped by data type.

Note

By default, parameters with the same data type use a single contiguous memory chunk. For some models with a huge number of parameters, however, splitting a large memory chunk into several smaller ones may improve performance; in that case, use ‘fusion_size’ to limit the maximum chunk size.

Parameters

fusion_size (int) – Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
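
As an illustration of the fusion_size trade-off, the following plain-Python sketch plans how parameters could be grouped: one chunk per data type when fusion_size is 0, otherwise chunks of at most fusion_size bytes. The function plan_chunks and its greedy, in-order packing rule are assumptions made for illustration; MindSpore's actual chunking may differ.

```python
from collections import defaultdict


def plan_chunks(params, fusion_size=0):
    """Group (name, dtype, nbytes) triples into contiguous chunks per dtype.

    Hypothetical planner: parameters are packed greedily in order, and a new
    chunk is started whenever adding the next one would exceed fusion_size.
    fusion_size == 0 means one chunk per dtype (no size limit). A single
    parameter larger than fusion_size still gets a chunk of its own.
    """
    by_dtype = defaultdict(list)
    for name, dtype, nbytes in params:
        by_dtype[dtype].append((name, nbytes))

    chunks = []
    for dtype, items in by_dtype.items():
        current, used = [], 0
        for name, nbytes in items:
            if fusion_size and current and used + nbytes > fusion_size:
                chunks.append((dtype, current))  # close the full chunk
                current, used = [], 0
            current.append(name)
            used += nbytes
        if current:
            chunks.append((dtype, current))
    return chunks


params = [("w1", "float32", 400), ("w2", "float32", 400),
          ("w3", "float32", 400), ("b1", "float16", 20)]
print(plan_chunks(params))       # one float32 chunk and one float16 chunk
print(plan_chunks(params, 800))  # float32 split into ['w1', 'w2'] and ['w3']
```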

generate_scope()

Generate the scope for each cell object in the network.

get_flags()

Gets the self-defined attributes of the cell, which can be added with the add_flags method.

get_func_graph_proto()

Returns the graph as a binary proto.

get_inputs()

Returns the dynamic inputs of the cell, as configured by set_inputs.

Returns

inputs (tuple), Inputs of the Cell object.

Warning

This is an experimental API that is subject to change or deletion.

get_parameters(expand=True)

Returns an iterator over cell parameters.

Yields parameters of this cell. If expand is true, yield parameters of this cell and all subcells.

Parameters

expand (bool) – If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters that are direct members of this cell. Default: True.

Returns

Iteration, all parameters at the cell.

Examples

>>> from mindspore import nn
>>> net = nn.Dense(3, 4)
>>> parameters = []
>>> for item in net.get_parameters():
...     parameters.append(item)
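
The effect of expand can be modelled in a few lines of plain Python. ToyCell below is a hypothetical stand-in, not MindSpore code, and it simply yields every parameter it finds without any special handling of parameters shared between cells.

```python
class ToyCell:
    """Hypothetical, simplified stand-in for mindspore.nn.Cell (not the real class)."""

    def __init__(self, params=(), children=()):
        self.params = list(params)      # parameters owned directly by this cell
        self.children = list(children)  # immediate child cells

    def _walk(self):
        """Yield this cell, then every descendant, depth-first."""
        yield self
        for child in self.children:
            yield from child._walk()

    def get_parameters(self, expand=True):
        """Yield own parameters, plus those of all subcells when expand is True."""
        cells = self._walk() if expand else [self]
        for cell in cells:
            yield from cell.params


dense = ToyCell(params=["dense.weight", "dense.bias"])
net = ToyCell(params=["scale"], children=[dense])
print(list(net.get_parameters()))              # ['scale', 'dense.weight', 'dense.bias']
print(list(net.get_parameters(expand=False)))  # ['scale']
```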

get_scope()

Returns the scope of a cell object in one network.

Returns

String, scope of the cell.

infer_param_pipeline_stage()

Infer pipeline stages of all parameters in the cell.

Note

  • If a parameter does not belong to any cell for which pipeline_stage has been set, add its pipeline_stage information with add_pipeline_stage.

  • If a parameter P is used by two operators in different stages “stageA” and “stageB”, call P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB) to add its stage information before calling infer_param_pipeline_stage.

Returns

The parameters that belong to the current stage in pipeline parallelism.

Raises

RuntimeError – If a parameter does not belong to any stage.

init_parameters_data(auto_parallel_mode=False)

Initialize all parameters and replace the original saved parameters in cell.

Note

After init_parameters_data is called, trainable_params() and similar interfaces may return different parameter instances, so do not save their results.

Parameters

auto_parallel_mode (bool) – Whether running in auto parallel mode. Default: False.

Returns

Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.

insert_child_to_cell(child_name, child_cell)

Adds a child cell to the current cell with a given name.

Parameters
  • child_name (str) – Name of the child cell.

  • child_cell (Cell) – The child cell to be inserted.

Raises
  • KeyError – If the child cell’s name is incorrect or duplicates another child’s name.

  • TypeError – If the type of child_name is not str.

  • TypeError – If the type of child_cell is incorrect.

insert_param_to_cell(param_name, param, check_name_contain_dot=True)

Adds a parameter to the current cell.

Inserts a parameter with given name to the cell. The method is currently used in mindspore.nn.Cell.__setattr__.

Parameters
  • param_name (str) – Name of the parameter.

  • param (Parameter) – Parameter to be inserted to the cell.

  • check_name_contain_dot (bool) – Whether to check that param_name does not contain a dot (‘.’). Default: True.

Raises
  • KeyError – If the name of parameter is null or contains dot.

  • TypeError – If the type of parameter is not Parameter.

name_cells()

Returns the immediate child cells of the network.

Each entry maps the name of a child cell to the cell itself.

Returns

Dict, all the child cells and corresponding names in the cell.

property param_prefix

The prefix of the names of the current cell’s direct child parameters.

property parameter_layout_dict

parameter_layout_dict represents the tensor layout of a parameter, which is inferred by shard strategy and distributed operator information.

parameters_and_names(name_prefix='', expand=True)

Returns an iterator over cell parameters.

Each item is a (name, parameter) pair.

Parameters
  • name_prefix (str) – Namespace. Default: ‘’.

  • expand (bool) – If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters that are direct members of this cell. Default: True.

Returns

Iteration, all the names and corresponding parameters in the cell.

Examples

>>> from mindspore import nn
>>> n = nn.Dense(3, 4)
>>> names = []
>>> for m in n.parameters_and_names():
...     if m[0]:
...         names.append(m[0])

parameters_broadcast_dict(recurse=True)

Gets the parameters broadcast dictionary of this cell.

Parameters

recurse (bool) – Whether to include the parameters of subcells. Default: True.

Returns

OrderedDict, the parameters broadcast dictionary.

parameters_dict(recurse=True)

Gets the parameters dictionary of this cell.

Parameters

recurse (bool) – Whether to include the parameters of subcells. Default: True.

Returns

OrderedDict, the parameters dictionary.

place(role, rank_id)

Sets a label for all operators in this cell. The label tells the MindSpore compiler on which process this cell should be launched, and each process is identified by a label made up of the input role and rank_id. By assigning different labels to different cells, which are then launched on different processes, users can run a distributed training or prediction job.

Note

  • This method is effective only after mindspore.communication.init() is called for dynamic cluster building.

Parameters
  • role (str) – The role of the process on which this cell will be launched. Only ‘MS_WORKER’ is supported for now.

  • rank_id (int) – The rank id of the process on which this cell will be launched. The rank is unique in processes with the same role.

Examples

>>> from mindspore import context
>>> import mindspore.nn as nn
>>> context.set_context(mode=context.GRAPH_MODE)
>>> fc = nn.Dense(2, 3)
>>> fc.place('MS_WORKER', 0)

recompute(**kwargs)

Sets the cell to be recomputed. All primitives in the cell except the outputs are set to be recomputed. If a recomputed primitive feeds into backward nodes that compute gradients, its intermediate activation is recomputed in the backward pass instead of being stored during the forward pass.

Note

  • If the computation involves randomness or global variables, equivalence is currently not guaranteed.

  • If the recompute API of a primitive in this cell is also called, the recompute mode of that primitive follows the primitive’s own recompute API.

  • The interface can be configured only once. Therefore, when the parent cell is configured, the child cell should not be configured.

  • The outputs of the cell are excluded from recomputation by default, based on experience that this reduces the memory footprint. If a cell has only one primitive and that primitive should be recomputed, use the primitive’s recompute API.

  • If memory is still available after recomputation is applied, configure ‘mp_comm_recompute=False’ to improve performance if necessary.

  • If memory is still insufficient after recomputation is applied, configure ‘parallel_optimizer_comm_recompute=True’ to save more memory if necessary. Cells in the same fusion group should use the same parallel_optimizer_comm_recompute configuration.

Parameters
  • mp_comm_recompute (bool) – Specifies whether the model parallel communication operators in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.

  • parallel_optimizer_comm_recompute (bool) – Specifies whether the communication operator allgathers introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
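
The memory/compute trade-off behind recompute can be seen in a scalar example. The plain-Python sketch below computes y = w2 * relu(w1 * x) and either keeps the intermediate value a = w1 * x for the backward pass or discards it and recomputes it when the gradients are needed. All names are hypothetical, and this is only an illustration of the idea, not how MindSpore implements recomputation.

```python
def relu(a):
    return max(a, 0.0)


def relu_grad(a):
    return 1.0 if a > 0 else 0.0


def forward(w1, w2, x, keep_activation):
    """Compute y = w2 * relu(w1 * x); optionally keep the intermediate value a."""
    a = w1 * x
    y = w2 * relu(a)
    saved = {"a": a} if keep_activation else {}  # what is held until the backward pass
    return y, saved


def backward(w1, w2, x, saved):
    """Return (dy/dw1, dy/dw2), recomputing a = w1 * x if it was not saved."""
    a = saved["a"] if "a" in saved else w1 * x  # recompute instead of reading memory
    dy_dw2 = relu(a)
    dy_dw1 = w2 * relu_grad(a) * x
    return dy_dw1, dy_dw2


w1, w2, x = 0.5, 3.0, 2.0
y_keep, saved_keep = forward(w1, w2, x, keep_activation=True)
y_rec, saved_rec = forward(w1, w2, x, keep_activation=False)
print(saved_keep, saved_rec)  # {'a': 1.0} {}
print(backward(w1, w2, x, saved_keep) == backward(w1, w2, x, saved_rec))  # True
```

Both variants produce the same gradients; the recomputed one holds nothing between the passes at the cost of evaluating w1 * x a second time.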

register_backward_hook(hook_fn)

Register the backward hook function.

Note

  • The register_backward_hook(hook_fn) does not work in graph mode or functions decorated with ‘jit’.

  • The ‘hook_fn’ must be defined as shown in the example below. cell_id is the information of the registered Cell object, including its name and ID. grad_input is the gradient passed to the Cell. grad_output is the gradient computed and passed to the next Cell or primitive; it can be modified by returning a new output gradient.

  • The ‘hook_fn’ should have the following signature: hook_fn(cell_id, grad_input, grad_output) -> new output gradient or None.

  • The ‘hook_fn’ is executed in the Python environment. To avoid failures when switching to graph mode, it is not recommended to call this method in the construct function of the Cell object. In PyNative mode, if register_backward_hook is called in the construct function of the Cell object, a new hook function is added each time the Cell object runs.

Parameters

hook_fn (function) – Python function. Backward hook function.

Returns

Handle, an instance of mindspore.common.hook_handle.HookHandle corresponding to hook_fn. The handle can be used to remove the added hook_fn by calling handle.remove().

Raises

TypeError – If hook_fn is not a Python function.

Supported Platforms: Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def backward_hook_fn(cell_id, grad_input, grad_output):
...     print("backward input:", grad_input)
...     print("backward output:", grad_output)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...         self.handle = self.relu.register_backward_hook(backward_hook_fn)
...
...     def construct(self, x):
...         x = x + x
...         x = self.relu(x)
...         return x
>>> grad = GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)

register_forward_hook(hook_fn)

Registers a forward hook function for the Cell.

Note

  • The register_forward_hook(hook_fn) does not work in graph mode or functions decorated with ‘jit’.

  • ‘hook_fn’ must be defined as shown in the example below. cell_id is the information of the registered Cell object, including its name and ID. inputs are the forward input objects passed to the Cell. output is the forward output object of the Cell. The ‘hook_fn’ can modify the forward output object by returning a new one.

  • It should have the following signature: hook_fn(cell_id, inputs, output) -> new output object or None.

  • To avoid failures when switching to graph mode, it is not recommended to call this method in the construct function of the Cell object. In PyNative mode, if register_forward_hook is called in the construct function of the Cell object, a new hook function is added each time the Cell object runs.

Parameters

hook_fn (function) – Python function. Forward hook function.

Returns

Handle, an instance of mindspore.common.hook_handle.HookHandle corresponding to hook_fn. The handle can be used to remove the added hook_fn by calling handle.remove().

Raises

TypeError – If hook_fn is not a Python function.

Supported Platforms: Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def forward_hook_fn(cell_id, inputs, output):
...     print("forward inputs:", inputs)
...     print("forward output:", output)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.mul = nn.MatMul()
...         self.handle = self.mul.register_forward_hook(forward_hook_fn)
...
...     def construct(self, x, y):
...         x = x + x
...         x = self.mul(x, y)
...         return x
>>> grad = GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
                dtype=Float32, value= [ 1.00000000e+00]))
forward output: 2.0
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
value= [ 2.00000000e+00]))
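
The return-value rule for forward hooks, and removal through the returned handle, can be sketched in plain Python as follows. ToyCell and HookHandle are hypothetical stand-ins (the real handle is mindspore.common.hook_handle.HookHandle), and passing the class name as cell_id is a simplification.

```python
import itertools


class HookHandle:
    """Hypothetical handle; remove() unregisters the hook it was created for."""

    def __init__(self, hooks, key):
        self._hooks, self._key = hooks, key

    def remove(self):
        self._hooks.pop(self._key, None)


class ToyCell:
    """Hypothetical, simplified stand-in for a Cell with forward hooks."""

    _ids = itertools.count()

    def __init__(self):
        self._forward_hooks = {}  # key -> hook_fn, kept in registration order

    def register_forward_hook(self, hook_fn):
        key = next(self._ids)
        self._forward_hooks[key] = hook_fn
        return HookHandle(self._forward_hooks, key)

    def construct(self, x, y):
        return x * y

    def __call__(self, *inputs):
        output = self.construct(*inputs)
        for hook_fn in list(self._forward_hooks.values()):
            new_output = hook_fn(type(self).__name__, inputs, output)
            if new_output is not None:  # returning None keeps the output unchanged
                output = new_output
        return output


cell = ToyCell()
cell.register_forward_hook(lambda cell_id, inputs, output: print("forward output:", output))
handle = cell.register_forward_hook(lambda cell_id, inputs, output: output + 100)
print(cell(2, 3))  # prints "forward output: 6", then 106
handle.remove()
print(cell(2, 3))  # prints "forward output: 6", then 6
```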

register_forward_pre_hook(hook_fn)

Register forward pre hook function for Cell object.

Note

  • The register_forward_pre_hook(hook_fn) does not work in graph mode or functions decorated with ‘jit’.

  • ‘hook_fn’ must be defined as shown in the example below. cell_id is the information of the registered Cell object, including its name and ID. inputs are the forward input objects passed to the Cell. The ‘hook_fn’ can modify the forward inputs by returning new input objects.

  • It should have the following signature: hook_fn(cell_id, inputs) -> new input objects or None.

  • To avoid failures when switching to graph mode, it is not recommended to call this method in the construct function of the Cell object. In PyNative mode, if register_forward_pre_hook is called in the construct function of the Cell object, a new hook function is added each time the Cell object runs.

Parameters

hook_fn (function) – Python function. Forward pre hook function.

Returns

Handle, an instance of mindspore.common.hook_handle.HookHandle corresponding to hook_fn. The handle can be used to remove the added hook_fn by calling handle.remove().

Raises

TypeError – If hook_fn is not a Python function.

Supported Platforms: Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def forward_pre_hook_fn(cell_id, inputs):
...     print("forward inputs:", inputs)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.mul = nn.MatMul()
...         self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn)
...
...     def construct(self, x, y):
...         x = x + x
...         x = self.mul(x, y)
...         return x
>>> grad = GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
                dtype=Float32, value= [ 1.00000000e+00]))
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
value= [ 2.00000000e+00]))
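
Forward pre hooks can be sketched the same way: the hook sees the inputs before construct runs, and a non-None return value replaces them. ToyCell is a hypothetical stand-in, and in this sketch the hook is assumed to return a tuple of new inputs.

```python
class ToyCell:
    """Hypothetical, simplified stand-in for a Cell with forward pre hooks."""

    def __init__(self):
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook_fn):
        self._pre_hooks.append(hook_fn)

    def construct(self, x, y):
        return x * y

    def __call__(self, *inputs):
        for hook_fn in self._pre_hooks:
            new_inputs = hook_fn("ToyCell", inputs)
            if new_inputs is not None:  # None keeps the original inputs
                inputs = new_inputs     # assumed to be a tuple in this sketch
        return self.construct(*inputs)


cell = ToyCell()
cell.register_forward_pre_hook(lambda cell_id, inputs: print("forward inputs:", inputs))
cell.register_forward_pre_hook(lambda cell_id, inputs: (inputs[0] + inputs[0], inputs[1]))
print(cell(1.0, 1.0))  # prints "forward inputs: (1.0, 1.0)", then 2.0
```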

remove_redundant_parameters()

Removes the redundant parameters.

This interface usually does not need to be called explicitly.

run_construct(cast_inputs, kwargs)

Runs the construct function.

Note

This function will be removed in a future version. Calling it is not recommended.

Parameters
  • cast_inputs (tuple) – The input objects of the Cell.

  • kwargs (dict) – The keyword arguments.

Returns

output, the output object of Cell.

set_boost(boost_type)

Configures the network to automatically enable an acceleration algorithm from the algorithm library in order to improve network performance.

If boost_type is not in the algorithm library, check the algorithm library for the available algorithms.

Note

Some acceleration algorithms may affect the accuracy of the network, so choose carefully.

Parameters

boost_type (str) – The acceleration algorithm.

Returns

Cell, the cell itself.

Raises

ValueError – If boost_type is not in the algorithm library.

set_broadcast_flag(mode=True)

Set parameter broadcast mode for this cell.

Parameters

mode (bool) – Specifies whether the mode is parameter broadcast. Default: True.

set_comm_fusion(fusion_type, recurse=True)

Sets comm_fusion for all the parameters in this cell. Please refer to the description of mindspore.Parameter.comm_fusion.

Note

The attribute value is overwritten each time the function is called.

Parameters
  • fusion_type (int) – The value of comm_fusion.

  • recurse (bool) – Whether to also set the trainable parameters of subcells. Default: True.

set_data_parallel()

For all primitive ops in this cell (including ops of the cells wrapped by this cell) whose parallel strategy is not specified, a data parallel strategy is generated instead of being searched automatically.

Note

Only effective while using auto_parallel_context = ParallelMode.AUTO_PARALLEL under graph mode.

Examples

>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net.set_data_parallel()

set_grad(requires_grad=True)

Sets the gradient flag of the cell. In PyNative mode, this flag specifies whether the network requires gradients. If True, the backward network needed to compute the gradients is generated when the forward network is executed.

Parameters

requires_grad (bool) – Whether the network requires gradients. If True, the cell constructs the backward network in PyNative mode. Default: True.

Returns

Cell, the cell itself.

set_inputs(*inputs)

Saves the inputs of the computation graph. The number of inputs should be the same as that of the dataset. When using Model with dynamic shape, make sure that all networks and loss functions passed to Model are configured with set_inputs. The inputs can be Tensors of either dynamic or static shape.

Parameters

inputs (tuple) – Inputs of the Cell object.

Warning

This is an experimental API that is subject to change or deletion.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor, context
>>>
>>> class reluNet(nn.Cell):
...     def __init__(self):
...         super(reluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = reluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
>>> output = net(input1)

set_jit_config(jit_config)

Set jit config for cell.

Parameters

jit_config (JitConfig) – Jit config for compile. For details, please refer to mindspore.JitConfig.

set_param_ps(recurse=True, init_in_server=False)

Sets whether the trainable parameters are updated by the parameter server and whether they are initialized on the server.

Note

It only works when the running task is in parameter server mode, and it is only supported in graph mode.

Parameters
  • recurse (bool) – Whether to also set the trainable parameters of subcells. Default: True.

  • init_in_server (bool) – Whether trainable parameters updated by parameter server are initialized on server. Default: False.

set_train(mode=True)

Sets the cell to training mode.

The cell itself and all its child cells are set to training mode. Layers that behave differently in training and prediction, such as BatchNorm, use this attribute to choose a branch: if it is True, the training branch is executed; otherwise, the prediction branch is executed.

Note

When Model.train() is executed, the framework calls Cell.set_train(True). When Model.eval() is executed, the framework calls Cell.set_train(False).

Parameters

mode (bool) – Specifies whether the model is training. Default: True.

Returns

Cell, the cell itself.
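
A plain-Python sketch of how a recursive set_train can drive layers with separate training and prediction branches: the mode is pushed down to every child, and a dropout-like layer checks its own flag. ToyCell and ToyDropout are hypothetical and only model the behavior described above.

```python
import random


class ToyCell:
    """Hypothetical, simplified stand-in for mindspore.nn.Cell (not the real class)."""

    def __init__(self, *children):
        self.children = list(children)
        self.training = False

    def set_train(self, mode=True):
        """Set the mode on this cell and, recursively, on every child; return self."""
        self.training = mode
        for child in self.children:
            child.set_train(mode)
        return self


class ToyDropout(ToyCell):
    """Toy layer with different training and prediction branches."""

    def __init__(self, keep_prob=0.5, seed=0):
        super().__init__()
        self.keep_prob = keep_prob
        self.rng = random.Random(seed)

    def __call__(self, values):
        if not self.training:  # prediction branch: identity
            return list(values)
        # training branch: zero each value with probability 1 - keep_prob, rescale the rest
        return [v / self.keep_prob if self.rng.random() < self.keep_prob else 0.0
                for v in values]


drop = ToyDropout()
net = ToyCell(ToyCell(drop))  # the dropout layer is a grandchild of net
net.set_train(False)
print(drop([1.0, 2.0]))  # [1.0, 2.0]
net.set_train(True)
print(drop.training)     # True
```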

shard(in_strategy, out_strategy=None, parameter_plan=None, device='Ascend', level=0)

Defines the input and output layouts of this cell; the parallel strategies of the remaining ops are then generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed execution in graph mode. in_strategy and out_strategy define the input and output layouts respectively. Each should be a tuple whose elements give the desired layout of the corresponding input or output, where None represents data parallel; see the description of mindspore.ops.Primitive.shard. The parallel strategies of the remaining operators are derived from the strategies specified for the inputs and outputs.

Note

Only effective in PYNATIVE_MODE with ParallelMode.AUTO_PARALLEL and search_mode in auto_parallel_context set to sharding_propagation. If the inputs contain a Parameter, its strategy should be set in in_strategy.

Parameters
  • in_strategy (tuple) – Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple defines the layout of the corresponding input and None represents a data parallel strategy.

  • out_strategy (Union[None, tuple]) – Define the layout of outputs, similar to in_strategy. It is not used currently. Default: None.

  • parameter_plan (Union[dict, None]) – Define the layout for the specified parameters. Each element in dict defines the layout of the parameter like “param_name: layout”. The key is a parameter name of type ‘str’. The value is a 1-D integer tuple, indicating the corresponding layout. If the parameter name is incorrect or the corresponding parameter has been set, the parameter setting will be ignored. Default: None.

  • device (str) – Selects a device target. It is not used currently. Supported values: “CPU”, “GPU”, “Ascend”. Default: “Ascend”.

  • level (int) – Option for the parallel strategy inference algorithm, namely the objective function: maximize the computation-over-communication ratio, maximize speed performance, minimize memory usage, etc. It is not used currently. Supported values: 0, 1, 2. Default: 0.

Returns

Cell, the cell itself.

Examples

>>> import mindspore.nn as nn
>>>
>>> class Block(nn.Cell):
...   def __init__(self):
...     super(Block, self).__init__()
...     self.dense1 = nn.Dense(10, 10)
...     self.relu = nn.ReLU()
...     self.dense2 = nn.Dense(10, 10)
...   def construct(self, x):
...     x = self.relu(self.dense2(self.relu(self.dense1(x))))
...     return x
>>>
>>> class example(nn.Cell):
...   def __init__(self):
...     super(example, self).__init__()
...     self.block1 = Block()
...     self.block2 = Block()
...     self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
...                       parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
...   def construct(self, x):
...     x = self.block1(x)
...     x = self.block2(x)
...     return x
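
Each tuple in in_strategy gives, for one input, the number of slices along each dimension, so ((2, 1),) in the example above splits the first dimension of block2's input in two. The plain-Python helper below computes the resulting per-device slice shape. slice_shape is hypothetical, the batch size of 32 is arbitrary, and requiring each dimension to divide evenly is an assumption of the sketch.

```python
def slice_shape(full_shape, layout):
    """Per-device slice shape of one input for a layout tuple such as (2, 1).

    Each layout entry is the number of slices along the matching dimension.
    """
    if len(full_shape) != len(layout):
        raise ValueError("layout needs one entry per dimension")
    shape = []
    for dim, parts in zip(full_shape, layout):
        if dim % parts != 0:  # this sketch only accepts even splits
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


# in_strategy=((2, 1),) applied to a (batch=32, features=10) input of Block
print(slice_shape((32, 10), (2, 1)))  # (16, 10)
```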

to_float(dst_type)

Adds casts on all inputs of the cell and its child cells so that they run with the specified float type.

If dst_type is mindspore.dtype.float16, all the inputs of the Cell, including inputs, Parameters and Tensors, are cast to float16. Please refer to the usage in the source code of mindspore.amp.build_train_network().

Note

Each call overwrites the setting of the previous call.

Parameters

dst_type (mindspore.dtype) – The data type the cell runs with. dst_type can be mstype.float16 or mstype.float32.

Returns

Cell, the cell itself.

Raises

ValueError – If dst_type is not mstype.float32 or mstype.float16.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> from mindspore import dtype as mstype
>>>
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> net.to_float(mstype.float16)
Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>

trainable_params(recurse=True)

Returns a list of all trainable parameters.

Parameters

recurse (bool) – Whether to include the trainable parameters of subcells. Default: True.

Returns

List, the list of trainable parameters.

untrainable_params(recurse=True)

Returns a list of all untrainable parameters.

Parameters

recurse (bool) – Whether to include the untrainable parameters of subcells. Default: True.

Returns

List, the list of untrainable parameters.

update_cell_prefix()

Updates the param_prefix of all child cells.

After it is called, the name prefix of each child cell can be obtained through ‘_param_prefix’.

update_cell_type(cell_type)

Updates the current cell type; this is used when a quantization-aware training network is encountered.

After it is called, the cell type is set to ‘cell_type’.

Parameters

cell_type (str) – The type of cell to be updated. cell_type can be “quant” or “second-order”.

update_parameters_name(prefix='', recurse=True)

Adds the prefix string to the names of parameters.

Parameters
  • prefix (str) – The prefix string. Default: ‘’.

  • recurse (bool) – Whether to include the parameters of subcells. Default: True.