mindspore_gs.quantization.SlbQuantAwareTraining

class mindspore_gs.quantization.SlbQuantAwareTraining(config=None)[source]

Implementation of the SLB quantization algorithm. This algorithm regards the discrete weights in an arbitrary quantized neural network as searchable variables and uses a differentiable method to search for them accurately. In particular, each weight is represented as a probability distribution over the discrete value set. The probabilities are optimized during training, and the values with the highest probability are selected to establish the desired quantized network. See more details in Searching for Low-Bit Weights in Quantized Neural Networks.
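
The following schematic sketch illustrates the core idea only and is not the library's internal implementation; the array shapes and the use of NumPy are illustrative assumptions for a 1-bit convolution weight.

>>> import numpy as np
>>> values = np.array([-1.0, 1.0])             # discrete value set for 1-bit weights
>>> logits = np.random.randn(6, 1, 5, 5, 2)    # one logit per candidate value per weight
>>> t = 1.0                                    # temperature, increased during training
>>> def softmax(x):
...     e = np.exp(x - x.max(axis=-1, keepdims=True))
...     return e / e.sum(axis=-1, keepdims=True)
...
>>> probs = softmax(t * logits)                    # probability over {-1, +1} per weight
>>> soft_weight = (probs * values).sum(axis=-1)    # expectation optimized during training
>>> hard_weight = values[probs.argmax(axis=-1)]    # highest-probability values form the final network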

Note

This class calls other set functions to set individual attributes; refer to the corresponding set function for the errors it may raise. For example, the rules for quant_dtype are documented under set_weight_quant_dtype and set_act_quant_dtype.

Parameters

config (dict) –

Store attributes for quantization aware training; keys are attribute names and values are attribute values. Default: None. The supported attributes are listed below, followed by a configuration sketch:

  • quant_dtype (Union[QuantDtype, list(QuantDtype), tuple(QuantDtype)]): Datatype used to quantize weights and activations. The type is a QuantDtype, a list of two QuantDtype, or a tuple of two QuantDtype. If quant_dtype is a single QuantDtype, it is duplicated into a list of two QuantDtype, where the first element is the type for activations and the second is the type for weights. Consider the precision supported by the target hardware device in practical quantized-inference scenarios. Weight quantization currently supports INT4, INT2 and INT1, and activation quantization currently supports INT8. Default: (QuantDtype.INT8, QuantDtype.INT1).

  • enable_act_quant (bool): Whether to apply activation quantization during training. Default: False.

  • enable_bn_calibration (bool): Whether to apply batchnorm calibration during training. Default: False.

  • epoch_size (int): Total training epochs.

  • has_trained_epoch (int): The number of epochs already trained.

  • t_start_val (float): Initial value of the temperature hyperparameter. Default: 1.

  • t_start_time (float): Fraction of epochs after which the temperature hyperparameter starts changing. Default: 0.2.

  • t_end_time (float): Fraction of epochs after which the temperature hyperparameter stops changing. Default: 0.6.

  • t_factor (float): Multiplicative factor by which the temperature hyperparameter changes. Default: 1.2.
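
As a minimal sketch, the same settings applied through set functions in the Examples below can also be passed through the config dictionary at construction time; the key names follow the attribute list above, and slb_config is a hypothetical variable name.

>>> from mindspore_gs.quantization import SlbQuantAwareTraining
>>> from mindspore.common.dtype import QuantDtype
>>> slb_config = {"quant_dtype": (QuantDtype.INT8, QuantDtype.INT1),
...               "enable_act_quant": True,
...               "enable_bn_calibration": True,
...               "epoch_size": 100,
...               "has_trained_epoch": 0,
...               "t_start_val": 1.0,
...               "t_start_time": 0.2,
...               "t_end_time": 0.6,
...               "t_factor": 1.2}
>>> slb_quantization = SlbQuantAwareTraining(config=slb_config)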

Raises
  • TypeError – If quant_dtype is not a QuantDtype, or if any element of quant_dtype is not a QuantDtype.

  • TypeError – If enable_act_quant or enable_bn_calibration is not bool.

  • ValueError – If the length of quant_dtype is greater than 2.

  • TypeError – If epoch_size or has_trained_epoch is not an int.

  • TypeError – If t_start_val, t_start_time, t_end_time or t_factor is not float.

  • ValueError – If epoch_size is not greater than 0.

  • ValueError – If has_trained_epoch is less than 0.

  • ValueError – If t_start_val or t_factor is not greater than 0.

  • ValueError – If t_start_time or t_end_time is less than 0.

  • ValueError – If t_start_time or t_end_time is greater than 1.

Supported Platforms:

GPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore_gs.quantization import SlbQuantAwareTraining
>>> from mindspore.common.dtype import QuantDtype
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SLB QAT-Algorithm
>>> slb_quantization = SlbQuantAwareTraining()
>>> ## 3) Use set functions to change config
>>> ## 3.1) set_weight_quant_dtype sets the weight quantization bit width; QuantDtype.INT4, QuantDtype.INT2
>>> ## and QuantDtype.INT1 are currently supported.
>>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1)
>>> ## 3.2) set_act_quant_dtype sets the activation quantization bit width; only QuantDtype.INT8 is currently supported.
>>> slb_quantization.set_act_quant_dtype(QuantDtype.INT8)
>>> ## 3.3) set_enable_act_quant sets whether to apply activation quantization.
>>> slb_quantization.set_enable_act_quant(True)
>>> ## 3.4) set_enable_bn_calibration sets whether to apply batchnorm calibration.
>>> slb_quantization.set_enable_bn_calibration(True)
>>> ## 3.5) set_epoch_size sets the total number of training epochs.
>>> slb_quantization.set_epoch_size(100)
>>> ## 3.6) set_has_trained_epoch sets the number of epochs already trained.
>>> slb_quantization.set_has_trained_epoch(0)
>>> ## 3.7) set_t_start_val sets the initial value of the temperature hyperparameter.
>>> slb_quantization.set_t_start_val(1.0)
>>> ## 3.8) set_t_start_time sets the fraction of epochs after which the temperature hyperparameter starts changing.
>>> slb_quantization.set_t_start_time(0.2)
>>> ## 3.9) set_t_end_time sets the fraction of epochs after which the temperature hyperparameter stops changing.
>>> slb_quantization.set_t_end_time(0.6)
>>> ## 3.10) set_t_factor sets the multiplicative factor by which the temperature hyperparameter changes.
>>> slb_quantization.set_t_factor(1.2)
>>> ## 4) Print SLB QAT-Algorithm object and check the config setting result
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the value of the attribute weight_quant_dtype is INT1
>>> ## Since we set act_quant_dtype to be QuantDtype.INT8, the value of the attribute act_quant_dtype is INT8
>>> ## Since we set enable_act_quant to be True, the value of the attribute enable_act_quant is True
>>> ## Since we set enable_bn_calibration to be True, the value of the attribute enable_bn_calibration is True
>>> ## Since we set epoch_size to be 100, the value of the attribute epoch_size is 100
>>> ## Since we set has_trained_epoch to be 0, the value of the attribute has_trained_epoch is 0
>>> ## Since we set t_start_val to be 1.0, the value of the attribute t_start_val is 1.0
>>> ## Since we set t_start_time to be 0.2, the value of the attribute t_start_time is 0.2
>>> ## Since we set t_end_time to be 0.6, the value of the attribute t_end_time is 0.6
>>> ## Since we set t_factor to be 1.2, the value of the attribute t_factor is 1.2
>>> print(slb_quantization)
SlbQuantAwareTraining<weight_quant_dtype=INT1, act_quant_dtype=INT8, enable_act_quant=True, enable_bn_calibration=True, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2>
>>> ## 5) Apply SLB QAT-algorithm to origin network
>>> net_qat = slb_quantization.apply(net)
>>> ## 6) Print the network and check the result. Conv2d should be transformed into a QuantizeWrapperCell.
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the bit_num value of fake_quant_weight
>>> ## should be 1, and the weight_bit_num value of Conv2dSlbQuant should be 1.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])>
    >
  (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])>
  (Conv2dSlbQuant): QuantizeWrapperCell<
    (_handler): Conv2dSlbQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1>
      >
    (_input_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  >
>>> ## 7) Convert the compressed network to a standard network before exporting to MindIR.
>>> net_qat = slb_quantization.convert(net_qat)
>>> data_in = mindspore.Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
>>> file_name = "./conv.mindir"
>>> mindspore.export(net_qat, data_in, file_name=file_name, file_format="MINDIR")
>>> graph = mindspore.load(file_name)
>>> mindspore.nn.GraphCell(graph)
apply(network: Cell)[source]

Apply the SLB quantization algorithm to network. The following steps make network available for quantization aware training:

  1. Fuse certain cells in network using the pattern engine defined by the net policy.

  2. Propagate the layer policies defined through the cells.

  3. Remove redundant fake quantizers.

  4. Apply the layer policies to convert normal cells into QuantizeWrapperCell.

Parameters

network (Cell) – Network to be quantized.

Returns

Quantized network.

callbacks(model: Model, dataset: Dataset)[source]

Define a TemperatureScheduler callback for the SLB QAT algorithm.

Parameters
  • model (Model) – Model to be used.

  • dataset (Dataset) – Dataset to be used.

Returns

A list of callback instances.

Raises
  • RuntimeError – If epoch_size is not initialized.

  • RuntimeError – If has_trained_epoch is not initialized.

  • ValueError – If epoch_size is not greater than has_trained_epoch.

  • ValueError – If t_end_time is less than t_start_time.

  • TypeError – If model is not mindspore.Model.

  • TypeError – If dataset is not mindspore.dataset.Dataset.
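
A minimal sketch of wiring the returned callbacks into training, reusing slb_quantization and net_qat from the Examples above; ds_train and the loss and optimizer choices are illustrative assumptions, not part of this API.

>>> import mindspore as ms
>>> from mindspore import nn
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> opt = nn.Momentum(net_qat.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> model = ms.Model(net_qat, loss_fn=loss, optimizer=opt)
>>> # ds_train is a hypothetical mindspore.dataset.Dataset of training samples
>>> cbs = slb_quantization.callbacks(model, ds_train)
>>> model.train(100, ds_train, callbacks=cbs)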

convert(net_opt: Cell, ckpt_path='')[source]

Define how to convert a compressed network to a standard network before exporting to MindIR.

Parameters
  • net_opt (Cell) – Network to be converted which is transformed by SlbQuantAwareTraining.apply.

  • ckpt_path (str) – Path to the checkpoint file for net_opt. Default is an empty string, which means no checkpoint file is loaded into net_opt.

Returns

An instance of Cell representing the converted network.

Raises
  • TypeError – If net_opt is not Cell.

  • TypeError – If ckpt_path is not string.

  • ValueError – If ckpt_path is not empty but is invalid.

  • RuntimeError – If ckpt_path points to a valid file but loading the checkpoint fails.
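
A short sketch of conversion with a checkpoint: the path below is a hypothetical example of a checkpoint saved during quantization aware training, loaded into the network before standardizing it.

>>> net_deploy = slb_quantization.convert(net_qat, ckpt_path="./checkpoint/qat.ckpt")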

set_act_quant_dtype(act_quant_dtype=QuantDtype.INT8)[source]

Set the value of act_quant_dtype in the quantization aware training config.

Parameters

act_quant_dtype (QuantDtype) – Datatype used to quantize activations. Default: QuantDtype.INT8.

Raises
  • TypeError – If act_quant_dtype is not QuantDtype.

  • ValueError – If act_quant_dtype is not QuantDtype.INT8; only QuantDtype.INT8 is currently supported.

set_enable_act_quant(enable_act_quant=False)[source]

Set the value of enable_act_quant in the quantization aware training config.

Parameters

enable_act_quant (bool) – Whether to apply activation quantization during training. Default: False.

Raises

TypeError – If enable_act_quant is not bool.

set_enable_bn_calibration(enable_bn_calibration=False)[source]

Set the value of enable_bn_calibration in the quantization aware training config.

Parameters

enable_bn_calibration (bool) – Whether to apply batchnorm calibration during training. Default: False.

Raises

TypeError – If enable_bn_calibration is not bool.

set_epoch_size(epoch_size)[source]

Set the value of epoch_size in the quantization aware training config.

Parameters

epoch_size (int) – The total training epochs.

Raises
  • TypeError – If epoch_size is not int.

  • ValueError – If epoch_size is not greater than 0.

set_has_trained_epoch(has_trained_epoch)[source]

Set the value of has_trained_epoch in the quantization aware training config.

Parameters

has_trained_epoch (int) – The number of epochs already trained.

Raises
  • TypeError – If has_trained_epoch is not int.

  • ValueError – If has_trained_epoch is less than 0.

set_t_end_time(t_end_time=0.6)[source]

Set the value of t_end_time in the quantization aware training config.

Parameters

t_end_time (float) – Fraction of epochs after which the temperature hyperparameter stops changing. Default: 0.6.

Raises
  • TypeError – If t_end_time is not float.

  • ValueError – If t_end_time is less than 0.0 or greater than 1.0.

set_t_factor(t_factor=1.2)[source]

Set the value of t_factor in the quantization aware training config.

Parameters

t_factor (float) – Multiplicative factor by which the temperature hyperparameter changes. Default: 1.2.

Raises
  • TypeError – If t_factor is not float.

  • ValueError – If t_factor is not greater than 0.

set_t_start_time(t_start_time=0.2)[source]

Set the value of t_start_time in the quantization aware training config.

Parameters

t_start_time (float) – Fraction of epochs after which the temperature hyperparameter starts changing. Default: 0.2.

Raises
  • TypeError – If t_start_time is not float.

  • ValueError – If t_start_time is less than 0.0 or greater than 1.0.

set_t_start_val(t_start_val=1.0)[source]

Set the value of t_start_val in the quantization aware training config.

Parameters

t_start_val (float) – Initial value of the temperature hyperparameter. Default: 1.0.

Raises
  • TypeError – If t_start_val is not float.

  • ValueError – If t_start_val is not greater than 0.
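
Taken together, t_start_val, t_start_time, t_end_time and t_factor define a temperature schedule over training epochs. Below is a minimal sketch of that schedule under the assumption that the temperature is multiplied by t_factor once per epoch between the start and end fractions; the exact update rule lives inside the TemperatureScheduler callback and may differ.

>>> epoch_size, t_start_val = 100, 1.0
>>> t_start_time, t_end_time, t_factor = 0.2, 0.6, 1.2
>>> start_epoch = int(epoch_size * t_start_time)   # temperature begins changing at epoch 20
>>> end_epoch = int(epoch_size * t_end_time)       # and stops changing at epoch 60
>>> t = t_start_val
>>> for epoch in range(epoch_size):
...     if start_epoch <= epoch < end_epoch:
...         t *= t_factor
...
>>> round(t, 2)
1469.77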

set_weight_quant_dtype(weight_quant_dtype=QuantDtype.INT1)[source]

Set the value of weight_quant_dtype in the quantization aware training config.

Parameters

weight_quant_dtype (QuantDtype) – Datatype used to quantize weights. Default: QuantDtype.INT1.

Raises
  • TypeError – If weight_quant_dtype is not QuantDtype.

  • ValueError – If weight_quant_dtype is not one of QuantDtype.INT1, QuantDtype.INT2 or QuantDtype.INT4; only these are currently supported.