mindspore.boost

Boost provides automatic acceleration for networks, such as Less BN, Gradient Freeze, Gradient Accumulation and so on.

Note

This feature is a beta feature, and we are still improving its functionality.

class mindspore.boost.AdaSum(rank, device_number, group_number, parameter_tuple)[source]

The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data parallel training of Deep Learning models.

Parameters
  • rank (int) – Rank number.

  • device_number (int) – Device number.

  • group_number (int) – Group number.

  • parameter_tuple (Tuple(Parameter)) – Tuple of parameters.

Inputs:
  • delta_weights (Tuple(Tensor)) - Tuple of gradients.

  • parameters (Tuple(Parameter)) - Tuple of current parameters.

  • old_parameters (Tuple(Parameter)) - Tuple of last parameters.

Outputs:
  • adasum_parameters (Tuple(Tensor)) - Tuple of parameters after adasum process.
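
To make the idea of adaptive summation concrete, the sketch below combines two gradient vectors with the pairwise rule described in the AdaSum literature: each gradient is scaled down by its projection onto the other before the two are added. It is plain Python and not the MindSpore implementation; the helper names dot and adasum_pair are hypothetical.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def adasum_pair(g1, g2):
    """Combine two gradients with the pairwise AdaSum rule.

    result = (1 - g1.g2 / (2*|g1|^2)) * g1 + (1 - g1.g2 / (2*|g2|^2)) * g2
    """
    inner = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    # Guard against zero gradients: fall back to a plain sum.
    c1 = 1.0 - inner / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - inner / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


orthogonal = adasum_pair([1.0, 0.0], [0.0, 1.0])  # behaves like a sum
parallel = adasum_pair([1.0, 2.0], [1.0, 2.0])    # behaves like an average
```

Orthogonal gradients are simply added, while identical gradients are averaged, which is the behaviour that distinguishes this rule from a plain sum or a plain mean.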

class mindspore.boost.AutoBoost(level='O0', boost_config_dict='')[source]

Provides automatic acceleration for networks.

Parameters
  • level (str) – Boost config level. Default: “O0”.

  • boost_config_dict (dict) –

    User config hyperparameter dict, recommended config format:

    {
        "boost": {
            "mode": "auto",
            "less_bn": False,
            "grad_freeze": False,
            "adasum": False,
            "grad_accumulation": False,
            "dim_reduce": False,
            "loss_scale_group": False
        },
        "common": {
            "gradient_split_groups": [50, 100],
            "device_number": 8
        },
        "less_bn": {
            "fn_flag": True,
            "gc_flag": True
        },
        "grad_freeze": {
            "param_groups": 10,
            "freeze_type": 1,
            "freeze_p": 0.7,
            "total_steps": 65536
        },
        "grad_accumulation": {
            "grad_accumulation_step": 1
        },
        "dim_reduce": {
            "rho": 0.55,
            "gamma": 0.9,
            "alpha": 0.001,
            "sigma": 0.4,
            "n_components": 32,
            "pca_mat_path": None,
            "weight_load_dir": None,
            "timeout": 1800
        }
    }
    
    • boost:

      • mode (str): How to set the boost. Supports [“auto”, “manual”, “enable_all”, “disable_all”]. Default: “auto”.

        • auto: Depend on the argument “boost_level” in class Model.

        • manual: Depend on “boost_config_dict”.

        • enable_all: Set all boost functions true.

        • disable_all: Set all boost functions false.

      • less_bn (bool): Whether to apply less_bn function. Default: False.

      • grad_freeze (bool): Whether to apply grad_freeze function. Default: False.

      • adasum (bool): Whether to apply adasum function. Default: False.

      • grad_accumulation (bool): Whether to apply grad_accumulation function. Default: False.

      • dim_reduce (bool): Whether to apply dim_reduce function. Default: False.

      • loss_scale_group (bool): Whether to apply loss_scale_group function. Default: False.

      If dim_reduce is set to true, all other functions are set to false. If grad_freeze is set to true and dim_reduce is false, all other functions are set to false.

    • common:

      • gradient_split_groups (list): The gradient split point of this network. Default: [50, 100].

      • device_number (int): Device number. Default: 8.

    • less_bn:

      • fn_flag (bool): Whether to replace fc with fn. Default: True.

      • gc_flag (bool): Whether to apply gc. Default: True.

    • grad_freeze:

      • param_groups (int): The number of parameter groups. Default: 10.

      • freeze_type (int): Gradient freeze grouping strategy, select from [0, 1]. Default: 1.

      • freeze_p (float): Gradient freezing probability. Default: 0.7.

      • total_steps (int): Total training steps. Default: 65536.

    • grad_accumulation:

      • grad_accumulation_step (int): Steps to accumulate gradients. Default: 1.

    • dim_reduce:

      The leading principles of dim_reduce:

      \[\begin{split}\begin{align} grad\_k &= pca\_mat \cdot grad\\ dk &= - bk \cdot grad\_k\\ sk &= rho ^ m \cdot dk\\ delta\_loss &= sigma \cdot grad\_k.T \cdot sk \end{align}\end{split}\]

      Here:

      • pca_mat (array): Shape \((k*n)\), k is part of n_components, n is the size of weight.

      • bk (array): Shape \((k*k)\), is the symmetric positive definite matrix in Quasi-Newton method.

      We need to find the m that satisfies:

      \[new\_loss < old\_loss + delta\_loss\]

      Then, get delta_grad to update the weights for model:

      \[\begin{split}\begin{align} grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\ new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\ delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk \end{align}\end{split}\]

      • rho (float): Generally, it does not need to be modified. Default: 0.55.

      • gamma (float): Generally, it does not need to be modified. Default: 0.9.

      • alpha (float): Generally, it does not need to be modified. Default: 0.001.

      • sigma (float): Generally, it does not need to be modified. Default: 0.4.

      • n_components (int): PCA component. Default: 32.

      • pca_mat_path (str): The path to load pca mat. Default: None.

      • weight_load_dir (str): The directory to load weight files saved as ckpt. Default: None.

      • timeout (int): Waiting time, in seconds, to load the local pca mat. Default: 1800.

    Users can load the config from a JSON file or pass the dictionary directly. Unconfigured parameters adopt the default values.
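
The default-filling and precedence rules described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the code AutoBoost runs; the names DEFAULT_BOOST_CONFIG, merge_boost_config and resolve_switches are hypothetical.

```python
import copy

# Defaults copied from the recommended config format above.
DEFAULT_BOOST_CONFIG = {
    "boost": {"mode": "auto", "less_bn": False, "grad_freeze": False, "adasum": False,
              "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": False},
    "common": {"gradient_split_groups": [50, 100], "device_number": 8},
    "less_bn": {"fn_flag": True, "gc_flag": True},
    "grad_freeze": {"param_groups": 10, "freeze_type": 1, "freeze_p": 0.7, "total_steps": 65536},
    "grad_accumulation": {"grad_accumulation_step": 1},
    "dim_reduce": {"rho": 0.55, "gamma": 0.9, "alpha": 0.001, "sigma": 0.4, "n_components": 32,
                   "pca_mat_path": None, "weight_load_dir": None, "timeout": 1800},
}

VALID_MODES = ("auto", "manual", "enable_all", "disable_all")


def merge_boost_config(user_config):
    """Hypothetical helper: overlay user settings on the defaults."""
    merged = copy.deepcopy(DEFAULT_BOOST_CONFIG)
    for section, values in user_config.items():
        merged.setdefault(section, {}).update(values)
    if merged["boost"]["mode"] not in VALID_MODES:
        raise ValueError("boost mode must be one of %s" % (VALID_MODES,))
    return merged


def resolve_switches(boost_section):
    """Apply the documented precedence: dim_reduce first, then grad_freeze."""
    switches = {k: v for k, v in boost_section.items() if k != "mode"}
    for winner in ("dim_reduce", "grad_freeze"):
        if switches.get(winner):
            # The winning function stays on; every other function is turned off.
            return {k: (k == winner) for k in switches}
    return switches


config = merge_boost_config({"boost": {"mode": "manual", "grad_freeze": True, "adasum": True},
                             "grad_freeze": {"freeze_p": 0.5}})
active = resolve_switches(config["boost"])
```

Here only freeze_p is overridden, so param_groups keeps its default of 10, and adasum ends up disabled because grad_freeze takes precedence.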

Raises

ValueError – If the boost mode is not in [“auto”, “manual”, “enable_all”, “disable_all”].

Supported Platforms:

Ascend

Examples

>>> from mindspore.boost import AutoBoost
>>> #1) when configuring the dict directly:
>>> boost_config_dict = {"boost": {"mode": "auto"}}
>>> boost = AutoBoost("O1", boost_config_dict)
>>>
>>> #2) when loading the dict from a json file:
>>> import json
>>> boost_json = "/path/boost_config.json"
>>> with open(boost_json, 'r') as fp:
...     boost_config_dict = json.load(fp)
>>> boost = AutoBoost("O1", boost_config_dict)

network_auto_process_eval(network)[source]

Boost network eval.

Parameters

network (Cell) – The inference network.

network_auto_process_train(network, optimizer)[source]

Boost network train.

Parameters
  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

class mindspore.boost.BoostTrainOneStepCell(network, optimizer, sens=1.0)[source]

Boost Network training package class.

Wraps the network with an optimizer. The resulting Cell is trained with input ‘*inputs’. The backward graph will be created in the construct function to update the parameter, and different parallel modes are available for training.

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Union[Cell]) – Optimizer for updating the weights.

  • sens (numbers.Number) – The scaling number to be filled as the input of backpropagation. Default value is 1.0.

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

Tensor, representing the loss value, whose shape is usually \(()\).

  • loss(Tensor): A scalar Tensor.

  • overflow(Tensor): A scalar Tensor whose type is bool.

  • loss scaling value(Tensor): A scalar Tensor.

Raises

TypeError – If sens is not a number.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import boost, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> #1) Using the existing WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>> class MyWithLossCell(nn.Cell):
...    def __init__(self, backbone, loss_fn):
...        super(MyWithLossCell, self).__init__(auto_prefix=False)
...        self._backbone = backbone
...        self._loss_fn = loss_fn
...
...    def construct(self, x, y, label):
...        out = self._backbone(x, y)
...        return self._loss_fn(out, label)
...
...    @property
...    def backbone_network(self):
...        return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)

adasum_process(loss, grads)[source]

Adasum algorithm process.

Parameters
  • loss (Tensor) – Tensor with shape \(()\).

  • grads (tuple(Tensor)) – Tuple of gradient tensors.

Returns

  • loss (Tensor) - Network loss, tensor with shape \(()\).

check_adasum_enable()[source]

Check adasum enable.

Returns

  • enable_adasum (bool) - Check whether the Adasum algorithm is enabled.

check_dim_reduce_enable()[source]

Check dim_reduce enable.

Returns

  • enable_dim_reduce (bool) - Check whether the dimensionality reduction second-order training algorithm is enabled.

gradient_accumulation_process(loss, grads, sens, *inputs)[source]

Gradient accumulation algorithm process.

Parameters
  • loss (Tensor) – Tensor with shape \(()\).

  • grads (tuple(Tensor)) – Tuple of gradient tensors.

  • sens (Tensor) – Tensor with shape \(()\).

  • inputs (tuple(Tensor)) – Tuple of input tensors with shape \((N, \ldots)\).

Returns

  • loss (Tensor) - Network loss, tensor with shape \(()\).

gradient_freeze_process(*inputs)[source]

Gradient freeze algorithm process.

Parameters

inputs (tuple(Tensor)) – Tuple of input tensors with shape \((N, \ldots)\).

Returns

  • loss (Tensor) - Network loss, tensor with shape \(()\).

class mindspore.boost.BoostTrainOneStepWithLossScaleCell(network, optimizer, scale_sense)[source]

Boost Network training with loss scaling.

This is a training step with loss scaling. It takes a network, an optimizer and a scale_sense argument, which is either a scale update Cell or a Tensor. The loss scale value can be updated on either the host side or the device side. BoostTrainOneStepWithLossScaleCell is compiled into a graph that takes *inputs as input data. When scale_sense is a Tensor, it acts as the loss scaling value; if you want to update it on the host side, the value must be provided. When scale_sense is not a Tensor, the loss scale update logic must be provided by a scale_sense of Cell type.
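
As a rough illustration of what a loss scale update rule does, the sketch below models the commonly used dynamic scheme on the host side: the scale shrinks when a step overflows and grows again after a window of consecutive clean steps. The class ToyDynamicLossScale is hypothetical and is not the update Cell itself; its argument names mirror those passed to nn.DynamicLossScaleUpdateCell in the example further down.

```python
class ToyDynamicLossScale:
    """Hypothetical host-side model of a dynamic loss scale update rule."""

    def __init__(self, loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000):
        self.loss_scale = float(loss_scale_value)
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        """Update the scale after one step and return the new value."""
        if overflow:
            # Shrink the scale (not below 1) and restart the window.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A full window without overflow: try a larger scale.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0
        return self.loss_scale


scaler = ToyDynamicLossScale(loss_scale_value=2 ** 12, scale_factor=2, scale_window=3)
history = [scaler.update(flag) for flag in (True, False, False, False, False)]
```

With a window of 3, one overflow halves the scale from 4096 to 2048, and the third clean step restores it to 4096.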

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Cell) – Optimizer for updating the weights.

  • scale_sense (Union[Tensor, Cell]) – If this value is of Cell type, it is the cell that implements the loss scaling update logic. If this value is of Tensor type, it must be a Tensor with shape \(()\) or \((1,)\).

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.

  • loss (Tensor) - Tensor with shape \(()\).

  • overflow (Tensor) - Tensor with shape \(()\), type is bool.

  • loss scaling value (Tensor) - Tensor with shape \(()\).

Raises
  • TypeError – If scale_sense is neither Cell nor Tensor.

  • ValueError – If shape of scale_sense is neither \((1,)\) nor \(()\).

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>> from mindspore.nn import WithLossCell
>>> from mindspore import dtype as mstype
>>> from mindspore import boost
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> #1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
>>> labels = Tensor(np.ones([out_features,]), mstype.float32)
>>> output = train_network(input, labels)
>>>
>>> #2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
>>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)

class mindspore.boost.DimReduce(network, optimizer, weight, pca_mat_local, n_components, rho, gamma, alpha, sigma, rank, rank_size)[source]

Dimension reduce training is a novel algorithm for accelerating the convergence of deep learning models.

\[\begin{split}\begin{align} grad\_k &= pca\_mat \cdot grad\\ dk &= - bk \cdot grad\_k\\ sk &= rho ^ m \cdot dk\\ delta\_loss &= sigma \cdot grad\_k.T \cdot sk \end{align}\end{split}\]

Here:

  • pca_mat (array): Shape \((k*n)\), k is part of n_components, n is the size of weight.

  • bk (array): Shape \((k*k)\), is the symmetric positive definite matrix in Quasi-Newton method.

We need to find the m that satisfies:

\[new\_loss < old\_loss + delta\_loss\]

Then, get delta_grad to update the weights for model:

\[\begin{split}\begin{align} grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\ new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\ delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk \end{align}\end{split}\]

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Union[Cell]) – Optimizer for updating the weights.

  • weight (Tuple(Parameter)) – Tuple of parameters.

  • pca_mat_local (numpy.ndarray) – For PCA operation, k*n, k is part of n_components, n is the size of weight.

  • n_components (int) – PCA.components.

  • rho (float) – Coefficient.

  • gamma (float) – Coefficient.

  • alpha (float) – Coefficient.

  • sigma (float) – Coefficient.

  • rank (int) – Rank number.

  • rank_size (int) – Rank size.

Inputs:
  • loss (Tensor) - Tensor with shape \(()\).

  • old_grad (Tuple(Tensor)) - Tuple of gradient tensors.

  • weight (Tuple(Tensor)) - Tuple of parameters.

  • weight_clone (Tuple(Tensor)) - Clone of weight.

  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:
  • loss (Tensor) - Tensor with shape \(()\).
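
The formulas in the class description can be traced on a tiny example. The sketch below is plain Python under loud assumptions: the loss is a toy quadratic 0.5 * |w|^2 (so grad equals w), the trial weights in the search for m are taken to be w + pca_mat.T · sk (the description does not state this), and bk is chosen deliberately large so that the search has to back off. All function names are hypothetical and this is not the DimReduce implementation.

```python
def matvec(mat, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def transpose(mat):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*mat)]


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def toy_loss(w):
    """Assumed toy loss 0.5 * |w|^2, whose gradient is w itself."""
    return 0.5 * dot(w, w)


def dim_reduce_step(w, grad, pca_mat, bk, old_grad_momentum,
                    rho=0.55, gamma=0.9, alpha=0.001, sigma=0.4, max_m=20):
    """Return (m, delta_grad, new_grad_momentum) for one illustrative step."""
    pca_mat_t = transpose(pca_mat)
    grad_k = matvec(pca_mat, grad)               # grad_k = pca_mat . grad
    dk = [-x for x in matvec(bk, grad_k)]        # dk = -bk . grad_k
    old_loss = toy_loss(w)
    for m in range(max_m):
        sk = [rho ** m * x for x in dk]          # sk = rho^m . dk
        delta_loss = sigma * dot(grad_k, sk)     # delta_loss = sigma . grad_k.T . sk
        step = matvec(pca_mat_t, sk)             # pca_mat.T . sk
        # Assumption: the trial weights are w + pca_mat.T . sk.
        trial = [wi + si for wi, si in zip(w, step)]
        if toy_loss(trial) < old_loss + delta_loss:
            break
    else:
        raise RuntimeError("no m below max_m satisfies the loss condition")
    grad_k_proj = matvec(pca_mat_t, grad_k)      # grad_k_proj = pca_mat.T . grad_k
    new_grad_momentum = [gamma * om + g - gp
                         for om, g, gp in zip(old_grad_momentum, grad, grad_k_proj)]
    delta_grad = [alpha * nm - s for nm, s in zip(new_grad_momentum, step)]
    return m, delta_grad, new_grad_momentum


weights = [1.0, 2.0, 3.0]
pca_mat = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]    # k = 2, n = 3
bk = [[4.0, 0.0], [0.0, 4.0]]                    # deliberately too large a step
m, delta_grad, momentum = dim_reduce_step(weights, weights, pca_mat, bk, [0.0, 0.0, 0.0])
```

In this toy setup the first three candidate steps overshoot, so the search settles on m = 3; the component of the gradient outside the PCA subspace only enters delta_grad through the momentum term scaled by alpha.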

class mindspore.boost.FreezeOpt(opt, train_parameter_groups=None, train_strategy=None)[source]

Optimizer that supports gradients freezing training.

Parameters
  • opt (Cell) – Non-freezing optimizer instance, such as ‘Momentum’, ‘SGD’.

  • train_parameter_groups (Union[tuple, list]) – Groups of parameters for gradients freezing training.

  • train_strategy (Union[tuple(int), list(int), Tensor]) – Strategy for gradients freezing training.

Supported Platforms:

Ascend

class mindspore.boost.GradientAccumulation(max_accumulation_step, optimizer)[source]

Accumulates the gradients over multiple steps, then calls the optimizer to apply the update.

Parameters
  • max_accumulation_step (int) – Steps to accumulate gradients.

  • optimizer (Cell) – Optimizer used.
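
The accumulate-then-update pattern can be sketched without MindSpore. The class below is hypothetical and uses a plain SGD update; it also assumes that the accumulated gradients are averaged before the update, which this page does not specify.

```python
class ToyGradientAccumulator:
    """Hypothetical sketch: apply an SGD update once every max_accumulation_step calls."""

    def __init__(self, max_accumulation_step, learning_rate=0.1):
        self.max_accumulation_step = max_accumulation_step
        self.learning_rate = learning_rate
        self.buffer = None
        self.count = 0

    def step(self, weights, grads):
        """Accumulate grads and return (weights, updated_flag)."""
        if self.buffer is None:
            self.buffer = [0.0] * len(grads)
        self.buffer = [b + g for b, g in zip(self.buffer, grads)]
        self.count += 1
        if self.count < self.max_accumulation_step:
            # Not enough steps yet: leave the weights untouched.
            return weights, False
        # Assumption: average the accumulated gradients before the update.
        mean_grads = [b / self.count for b in self.buffer]
        new_weights = [w - self.learning_rate * g for w, g in zip(weights, mean_grads)]
        self.buffer, self.count = None, 0
        return new_weights, True


accumulator = ToyGradientAccumulator(max_accumulation_step=2, learning_rate=0.5)
weights = [1.0, 1.0]
weights, updated_first = accumulator.step(weights, [0.2, 0.4])
weights, updated_second = accumulator.step(weights, [0.6, 0.0])
```

Only the second call changes the weights, using the mean of the two gradient tuples.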

class mindspore.boost.GradientFreeze(param_groups, freeze_type, freeze_p, total_steps)[source]

Gradients freezing algorithm, freezing the gradients of some layers randomly, to improve network training performance. The number and probability of frozen layers can be configured by users.

Parameters
  • param_groups (Union[tuple, list]) – Groups of parameters for gradients freezing training.

  • freeze_type (int) – Strategy of gradients freezing training.

  • freeze_p (float) – Probability of gradients freezing training.

  • total_steps (int) – Steps of the whole training.

Examples

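>>> from mindspore import boost
>>> # network and optimizer are assumed to be an existing training Cell and its optimizer.
>>> gradient_freeze_class = boost.GradientFreeze(10, 1, 0.5, 2000)
>>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)

freeze_generate(network, optimizer)[source]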

Generate freeze network and optimizer.

Parameters
  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, total_steps)[source]

Generate index sequence for gradient freezing training.

Parameters
  • parameter_groups_number (int) – The number of parameter groups.

  • freeze_strategy (int) – Gradient freeze grouping strategy, select from [0, 1].

  • freeze_p (float) – Gradient freezing probability.

  • total_steps (int) – Total training steps.
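
To give a feel for what a per-step freeze index sequence could look like, here is a purely illustrative generator. It is not the algorithm MindSpore uses and it ignores freeze_strategy entirely; it assumes that index 0 means "train every group" and that an index i > 0 means "freeze the first i groups". The function name toy_freeze_index_sequence and the seed argument are hypothetical.

```python
import random


def toy_freeze_index_sequence(parameter_groups_number, freeze_p, total_steps, seed=0):
    """Return one index per training step under the assumptions stated above."""
    rng = random.Random(seed)  # fixed seed so the sequence is reproducible
    sequence = []
    for _ in range(total_steps):
        if parameter_groups_number > 1 and rng.random() < freeze_p:
            # Freezing step: pick how many leading groups to freeze.
            sequence.append(rng.randint(1, parameter_groups_number - 1))
        else:
            # Regular step: train all parameter groups.
            sequence.append(0)
    return sequence


sequence = toy_freeze_index_sequence(parameter_groups_number=10, freeze_p=0.7, total_steps=1000)
frozen_fraction = sum(1 for idx in sequence if idx > 0) / len(sequence)
```

The fraction of freezing steps approaches freeze_p as total_steps grows, which is the sense in which freeze_p acts as a probability in this sketch.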

split_parameters_groups(net, freeze_para_groups_number)[source]

Split parameter groups for gradients freezing training.

Parameters
  • net (Cell) – The training network.

  • freeze_para_groups_number (int) – The number of gradient freeze groups.

class mindspore.boost.GroupLossScaleManager(init_loss_scale, loss_scale_groups)[source]

Enhanced mixed precision algorithm that supports applying different loss scales to different layers and updating the loss scales dynamically.

Parameters
  • init_loss_scale (Number) – The initialized loss scale value.

  • loss_scale_groups (List) – The loss scale groups, which are divided from the param list.

Inputs:
  • x (Tensor) - The output of last operator.

  • layer1 (int) - Current network layer value.

  • layer2 (int) - Last network layer value.

Outputs:
  • x (Tensor) - The output of _DynamicLossScale operator.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import boost, nn
>>>
>>> class Net(nn.Cell):
...     def __init__(self, enhanced_amp, num_class=10, num_channel=1):
...         super(Net, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...         self.enhanced_amp = enhanced_amp
...
...     def construct(self, x):
...         x = self.enhanced_amp(x, 0, 1)
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.enhanced_amp(x, 1, 2)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         x = self.enhanced_amp(x, 2, 3)
...         return x
>>>
>>> loss_scale_manager = boost.GroupLossScaleManager(4096, [])
>>> net = Net(loss_scale_manager)
>>> param_group1 = []
>>> param_group2 = []
>>> for param in net.trainable_params():
...     if 'conv' in param.name:
...         param_group1.append(param)
...     else:
...         param_group2.append(param)
>>> loss_scale_manager.loss_scale_groups = [param_group1, param_group2]
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> boost_config_dict = {"boost": {"mode": "manual", "less_bn": False, "grad_freeze": False, "adasum": False,
...                      "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": True}}
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager,
...               boost_level="O1", boost_config_dict=boost_config_dict)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> model.train(2, dataset)

get_loss_scale()[source]

Get loss scale value.

Returns

bool, loss_scale value.

get_update_cell()[source]

Returns the instance of mindspore.boost.GroupLossScaleManager.

Returns

mindspore.boost.GroupLossScaleManager.

set_loss_scale_status(loss_scale_number, init_loss_scale)[source]

Generate dynamic loss scale tuple and set overflow status list.

Parameters
  • loss_scale_number (int) – The number of loss scale.

  • init_loss_scale (float) – The initialized loss scale.

update_loss_scale_status(layer, update_ratio)[source]

Update dynamic loss scale.

Parameters
  • layer (int) – Current layer.

  • update_ratio (float) – The ratio of loss scale update.

Outputs:

float, new loss scale.

class mindspore.boost.LessBN(network, fn_flag=False)[source]

Reduce the number of BN automatically to improve the network performance and ensure the network accuracy.

Parameters
  • network (Cell) – Network to be modified.

  • fn_flag (bool) – Replace FC with FN. Default: False.

Examples

>>> from mindspore import boost
>>> # network is assumed to be an existing Cell.
>>> network = boost.LessBN(network)

class mindspore.boost.OptimizerProcess(opt)[source]

Process optimizer for Boost. Currently, this class supports adding GC (gradient centralization) tags and creating new optimizers.

Parameters

opt (Cell) – Optimizer used.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.boost import OptimizerProcess
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> optimizer_process = OptimizerProcess(optimizer)
>>> optimizer_process.add_grad_centralization(network)
>>> optimizer = optimizer_process.generate_new_optimizer()

add_grad_centralization(network)[source]

Add gradient centralization.

Parameters

network (Cell) – The training network.
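
For readers unfamiliar with the technique, gradient centralization removes the mean from each weight vector's gradient before the optimizer applies it. The sketch below shows the operation on a 2-D gradient in plain Python; it assumes each row holds the gradient of one output unit's weights, and the actual axis depends on the layer's weight layout. The function name centralize_gradient is hypothetical and this is not the code that add_grad_centralization runs.

```python
def centralize_gradient(grad_matrix):
    """Subtract the mean of each row from that row's entries."""
    centered = []
    for row in grad_matrix:
        mean = sum(row) / len(row)
        centered.append([g - mean for g in row])
    return centered


grad = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]
centered = centralize_gradient(grad)
```

After centralization every row sums to zero, so a row with constant gradient values contributes no update at all.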

static build_gc_params_group(params_dict, parameters)[source]

Build the parameter groups with gradient centralization.

Parameters
  • params_dict (dict) – The network’s parameter dict.

  • parameters (list) – The network’s parameter list.

static build_params_dict(network)[source]

Build the parameter dict of the network.

Parameters

network (Cell) – The training network.

generate_new_optimizer()[source]

Generate new optimizer.

class mindspore.boost.ParameterProcess[source]

Process parameter for Boost. Currently, this class supports creating group parameters and automatically setting gradient segmentation point.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>> from mindspore.boost import ParameterProcess
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight2')
...         self.matmul = ops.MatMul()
...         self.matmul2 = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         output2 = self.matmul2(x, self.weight2)
...         return output + output2
...
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> new_parameter = network.trainable_params()[:1]
>>> group_params = ParameterProcess.generate_group_params(new_parameter, network.trainable_params())

assign_parameter_group(parameters, split_point=None)[source]

Assign parameter group.

Parameters
  • parameters (list) – The network’s parameter list.

  • split_point (list) – The gradient split point of this network. Default: None.

static generate_group_params(parameters, origin_params)[source]

Generate group parameters.

Parameters
  • parameters (list) – The network’s parameter list.

  • origin_params (list) – The network’s origin parameter list.

mindspore.boost.freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None, max_accumulation_step=1)[source]

Generate freeze network and optimizer.

Parameters
  • reducer_flag (bool) – Reducer flag.

  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

  • sens (numbers.Number) – The scaling number.

  • grad (tuple(Tensor)) – Tuple of gradient tensors.

  • use_grad_accumulation (bool) – Use gradient accumulation flag.

  • mean (bool) – Gradients mean flag. Default: None.

  • degree (int) – Device number. Default: None.

  • max_accumulation_step (int) – Max accumulation steps. Default: 1.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>> from mindspore.boost.grad_freeze import freeze_cell, FreezeOpt
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> in_features, out_features = 16, 10
>>> network = Net(in_features, out_features)
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> optimizer = FreezeOpt(optimizer)
>>> grad = ops.GradOperation(get_by_list=True, sens_param=True)
>>> freeze_nets = freeze_cell(False, network, optimizer, 1.0, grad, False, None, None, 1)