mindspore_gs

MindSpore golden stick module.

class mindspore_gs.PrunerFtCompressAlgo(config)

Derived class of GoldenStick implementing the fine-tuning stage of the SCOP algorithm, which recovers a normal, pruned network from a knockoff network.

Parameters

config (Dict) – Configuration of PrunerFtCompressAlgo. PrunerFtCompressAlgo currently has no configurable options; the config parameter is kept in the constructor for compatibility.

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore_gs import PrunerKfCompressAlgo, PrunerFtCompressAlgo
>>> from mindspore import nn
>>> class NetToPrune(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToPrune, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> net = NetToPrune()
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> net_pruning_kf = kf_pruning.apply(net)
>>> ## 1) Define FineTune Algorithm
>>> ft_pruning = PrunerFtCompressAlgo({})
>>> ## 2) Apply FineTune algorithm to the knockoff network
>>> net_pruning_ft = ft_pruning.apply(net_pruning_kf)
>>> ## 3) Print network and check the result. Conv2d and bn should be transformed to MaskedConv2dbn.
>>> print(net_pruning_ft)
NetToPrune<
  (conv): MaskedConv2dbn<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
      padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
      (name=conv.bn.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
      (name=conv.bn.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
      (name=conv.bn.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
      (name=conv.bn.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (bn): SequentialCell<>
  >
apply(network)

Transform a knockoff network into a normal, pruned network.

Parameters

network (Cell) – Knockoff network.

Returns

Pruned network.
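The example above prints a MaskedConv2dbn wrapper after this step, which suggests that the pruned network carries a per-channel keep mask. The plain-Python sketch below shows one way such a mask could be derived. It assumes, following our reading of the SCOP paper (Tang et al., NeurIPS 2020), that each output channel has a learned scaling factor for its real feature and one for its knockoff feature, and that the channels whose real factor exceeds the knockoff factor by the smallest margin are pruned. scop_keep_mask, apply_mask and prune_ratio are hypothetical names, not part of mindspore_gs.

```python
def scop_keep_mask(real_scale, knockoff_scale, prune_ratio):
    """Return one bool per output channel: True = keep, False = prune.

    Hypothetical helper: the importance score (real minus knockoff scaling
    factor) is an assumption based on the SCOP paper, not mindspore_gs code.
    """
    if len(real_scale) != len(knockoff_scale):
        raise ValueError("real_scale and knockoff_scale must have the same length")
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    scores = [r - k for r, k in zip(real_scale, knockoff_scale)]
    n_prune = int(len(scores) * prune_ratio)
    # Channels with the lowest scores are treated as the least important ones.
    least_important = sorted(range(len(scores)), key=lambda i: scores[i])[:n_prune]
    pruned = set(least_important)
    return [i not in pruned for i in range(len(scores))]


def apply_mask(channels, mask):
    """Drop the pruned entries, e.g. the output-channel rows of a conv weight."""
    return [c for c, keep in zip(channels, mask) if keep]


mask = scop_keep_mask([0.9, 0.1, 0.5, 0.8], [0.2, 0.6, 0.4, 0.1], prune_ratio=0.5)
print(mask)  # [True, False, False, True]
print(apply_mask(["ch0", "ch1", "ch2", "ch3"], mask))  # ['ch0', 'ch3']
```

Pruning a fixed fraction keeps the sketch short; a threshold on the score would be an equally plausible design.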

class mindspore_gs.PrunerKfCompressAlgo(config)

Derived class of GoldenStick implementing the knockoff stage of the SCOP algorithm, which constructs effective knockoff counterparts for the network.

Parameters

config (Dict) – Configuration of PrunerKfCompressAlgo. PrunerKfCompressAlgo currently has no configurable options; the config parameter is kept in the constructor for compatibility.

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore_gs import PrunerKfCompressAlgo
>>> from mindspore import nn
>>> class NetToPrune(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToPrune, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be pruned
>>> net = NetToPrune()
>>> ## 2) Define Knockoff Algorithm
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> ## 3) Apply Knockoff algorithm to the original network
>>> net_pruning = kf_pruning.apply(net)
>>> ## 4) Print network and check the result. Conv2d and bn should be transformed to KfConv2d.
>>> print(net_pruning)
NetToPrune<
  (conv): KfConv2d<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
      padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
      (name=conv.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
      (name=conv.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
      (name=conv.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
      (name=conv.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (bn): SequentialCell<>
  >
apply(network)

Transform the input network into a knockoff network.

Parameters

network (Cell) – Network to be pruned.

Returns

Knockoff network.

class mindspore_gs.SimulatedQuantizationAwareTraining(config=None)

Derived class of GoldenStick implementing the simulated quantization aware training (SimQAT) algorithm.

Parameters

config (dict) –

Attributes for quantization aware training. Keys are attribute names and values are attribute values. Supported attributes are listed below:

  • quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during train and eval. The first element represents data flow and the second element represents weights. Default: (0, 0).

  • quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first element represents data flow and the second element represents weights. Consider the precision supported by the target hardware in practical quantized inference scenarios. Default: (QuantDtype.INT8, QuantDtype.INT8).

  • per_channel (Union[bool, list, tuple]): Quantization granularity. If True, quantization is per channel; otherwise, it is per layer. The first element represents data flow and the second element represents weights; the first element must currently be False. Default: (False, False).

  • symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If True, symmetric quantization is used; otherwise, asymmetric quantization is used. The first element represents data flow and the second element represents weights. Default: (False, False).

  • narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range. The first element represents data flow and the second element represents weights. Default: (False, False).

  • enable_fusion (bool): Whether to apply fusion before applying quantization. Default: False.

  • freeze_bn (int): Number of steps after which the BatchNorm parameters are fixed to the global mean and variance. Default: 10000000.

  • bn_fold (bool): Whether to use BN fold ops when simulating inference. Default: False.

  • one_conv_fold (bool): Whether to use one-conv BN fold ops when simulating inference. Default: True.
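The (data flow, weights) pairs above parameterize a uniform fake quantizer. The plain-Python sketch below shows how num_bits, symmetric and narrow_range could shape the integer grid, and how a per-layer or per-channel (min, max) becomes a quantize-then-dequantize step. The integer ranges (signed for symmetric, unsigned for asymmetric, narrow_range dropping the lowest code) follow a common convention and are an assumption here, not a statement about mindspore_gs internals; int_range, fake_quant and fake_quant_per_channel are hypothetical names.

```python
def int_range(num_bits=8, symmetric=False, narrow_range=False):
    """Integer grid of the quantizer (assumed convention, see lead-in)."""
    if symmetric:
        qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    else:
        qmin, qmax = 0, 2 ** num_bits - 1
    if narrow_range:
        qmin += 1  # drop the lowest code, e.g. int8 symmetric becomes [-127, 127]
    return qmin, qmax


def fake_quant(values, x_min, x_max, num_bits=8, symmetric=False, narrow_range=False):
    """Quantize then dequantize `values` with one (min, max) pair: per-layer granularity."""
    if num_bits < 2:
        raise ValueError("this sketch needs num_bits >= 2")
    qmin, qmax = int_range(num_bits, symmetric, narrow_range)
    if symmetric:
        max_abs = max(abs(x_min), abs(x_max))
        scale = max_abs / qmax if max_abs > 0 else 1.0
        zero_point = 0
    else:
        x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # keep 0.0 exactly representable
        scale = (x_max - x_min) / (qmax - qmin) if x_max > x_min else 1.0
        zero_point = round(qmin - x_min / scale)
    out = []
    for v in values:
        q = min(max(round(v / scale) + zero_point, qmin), qmax)  # clamp to the integer grid
        out.append((q - zero_point) * scale)
    return out


def fake_quant_per_channel(rows, **kwargs):
    """Per-channel variant: each output channel (row) gets its own (min, max)."""
    return [fake_quant(row, min(row), max(row), **kwargs) for row in rows]


print(int_range(8, symmetric=True, narrow_range=True))  # (-127, 127)
print(fake_quant([-1.0, -0.3, 0.0, 0.42, 1.0], -1.0, 1.0, symmetric=True, narrow_range=True))
```

With symmetric=True and narrow_range=True the int8 grid is symmetric around zero, so the zero point is 0 and 0.0 round-trips exactly.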

Raises
  • TypeError – If the element of quant_delay is not int.

  • TypeError – If the element of per_channel, symmetric, narrow_range, bn_fold, one_conv_fold is not bool.

  • TypeError – If the element of quant_dtype is not QuantDtype.

  • TypeError – If freeze_bn is not int.

  • ValueError – If freeze_bn is less than 0.

  • ValueError – If the length of quant_delay, quant_dtype, per_channel, symmetric or narrow_range is greater than 2.

  • ValueError – If the element of quant_delay is less than 0.

  • ValueError – If the first element of per_channel is True.

  • NotImplementedError – If the element of quant_dtype is not QuantDtype.INT8.
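The Raises list above can be read as a set of validation rules. The minimal plain-Python sketch below encodes them. It assumes that a scalar is broadcast to both the data-flow and weight positions, and it uses a stub QuantDtype enum in place of the real one; validate_simqat_config and its helpers are hypothetical, not the mindspore_gs implementation.

```python
from enum import Enum


class QuantDtype(Enum):
    """Stand-in for mindspore_gs's QuantDtype; only the members used here."""
    INT4 = 4
    INT8 = 8


def _as_pair(name, value):
    """Expand a scalar to (data flow, weights); a sequence may hold 1 or 2 elements."""
    if not isinstance(value, (list, tuple)):
        return (value, value)
    if len(value) > 2:
        raise ValueError(f"{name} must have at most 2 elements, got {len(value)}")
    if len(value) == 0:
        raise ValueError(f"{name} must not be empty")
    return (value[0], value[-1])


def _is_int(x):
    return isinstance(x, int) and not isinstance(x, bool)  # bool is a subclass of int


def validate_simqat_config(cfg):
    """Hypothetical validator mirroring the Raises list; returns a normalized copy."""
    out = {"quant_delay": _as_pair("quant_delay", cfg.get("quant_delay", (0, 0)))}
    for d in out["quant_delay"]:
        if not _is_int(d):
            raise TypeError("elements of quant_delay must be int")
        if d < 0:
            raise ValueError("elements of quant_delay must be >= 0")
    out["quant_dtype"] = _as_pair("quant_dtype", cfg.get("quant_dtype", (QuantDtype.INT8, QuantDtype.INT8)))
    for t in out["quant_dtype"]:
        if not isinstance(t, QuantDtype):
            raise TypeError("elements of quant_dtype must be QuantDtype")
        if t is not QuantDtype.INT8:
            raise NotImplementedError("only QuantDtype.INT8 is supported")
    for name in ("per_channel", "symmetric", "narrow_range"):
        out[name] = _as_pair(name, cfg.get(name, (False, False)))
        if not all(isinstance(v, bool) for v in out[name]):
            raise TypeError(f"elements of {name} must be bool")
    if out["per_channel"][0]:
        raise ValueError("per_channel for data flow (first element) must be False")
    for name, default in (("enable_fusion", False), ("bn_fold", False), ("one_conv_fold", True)):
        out[name] = cfg.get(name, default)
        if not isinstance(out[name], bool):
            raise TypeError(f"{name} must be bool")
    out["freeze_bn"] = cfg.get("freeze_bn", 10000000)
    if not _is_int(out["freeze_bn"]):
        raise TypeError("freeze_bn must be int")
    if out["freeze_bn"] < 0:
        raise ValueError("freeze_bn must be >= 0")
    return out


print(validate_simqat_config({"quant_delay": 900, "per_channel": (False, True)}))
```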

Supported Platforms:

GPU

Examples

>>> from mindspore_gs.quantization.simulated_quantization import SimulatedQuantizationAwareTraining
>>> from mindspore import nn
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SimQAT Algorithm
>>> simulated_quantization = SimulatedQuantizationAwareTraining()
>>> ## 3) Use set functions to change config
>>> simulated_quantization.set_enable_fusion(True)
>>> simulated_quantization.set_bn_fold(False)
>>> simulated_quantization.set_act_quant_delay(900)
>>> simulated_quantization.set_weight_quant_delay(900)
>>> simulated_quantization.set_act_per_channel(False)
>>> simulated_quantization.set_weight_per_channel(True)
>>> simulated_quantization.set_act_narrow_range(False)
>>> simulated_quantization.set_weight_narrow_range(False)
>>> ## 4) Apply SimQAT algorithm to origin network
>>> net_qat = simulated_quantization.apply(net)
>>> ## 5) Print network and check the result. The Conv2d and BatchNorm2d cells should be transformed to a
>>> ## QuantizeWrapperCell.
>>> ## Since enable_fusion is True and bn_fold is False, the Conv2d and BatchNorm2d cells are fused and
>>> ## converted to Conv2dBnWithoutFoldQuant.
>>> ## Since act_quant_delay is 900, the quant_delay of _input_quantizer and _output_quantizer is 900.
>>> ## Since weight_quant_delay is 900, the quant_delay of fake_quant_weight is 900.
>>> ## Since act_per_channel is False, the per_channel of _input_quantizer and _output_quantizer is False.
>>> ## Since weight_per_channel is True, the per_channel of fake_quant_weight is True.
>>> ## Since act_narrow_range is False, the narrow_range of _input_quantizer and _output_quantizer is False.
>>> ## Since weight_narrow_range is False, the narrow_range of fake_quant_weight is False.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=_handler.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=_handler.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=_handler.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=_handler.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (Conv2dBnWithoutFoldQuant): QuantizeWrapperCell<
    handler: in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, input quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900, output quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900
    (_handler): Conv2dBnWithoutFoldQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 6), quant_delay=900>
      (batchnorm): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.0030000000000000027, gamma=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  >
apply(network: Cell)

Apply the SimQAT algorithm to network. The following steps prepare the network for quantization aware training:

  1. Fuse certain cells in the network using the pattern engine defined by the net policy. Default fuse patterns: Conv2d + BatchNorm2d + ReLU, Conv2d + ReLU, Dense + BatchNorm2d + ReLU, Dense + BatchNorm2d, Dense + ReLU.

  2. Propagate the LayerPolicies defined in the NetPolicy through the network.

  3. Reduce redundant fake quantizers, that is, cases where two or more fake quantizers exist on the same tensor.

  4. Apply the LayerPolicies to convert normal cells into QuantizeWrapperCells. The actual fake quantizers are inserted into the network in this step.
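Step 1 can be pictured as pattern matching over a sequence of cell types. The toy sketch below greedily matches the longest default pattern at each position of a flat list of cell type names. The real pattern engine works on the network graph and is defined by the net policy, so fuse and DEFAULT_PATTERNS are hypothetical, illustration-only names.

```python
# The five default patterns listed in step 1.
DEFAULT_PATTERNS = [
    ("Conv2d", "BatchNorm2d", "ReLU"),
    ("Conv2d", "ReLU"),
    ("Dense", "BatchNorm2d", "ReLU"),
    ("Dense", "BatchNorm2d"),
    ("Dense", "ReLU"),
]


def fuse(layers, patterns=DEFAULT_PATTERNS):
    """Group a flat list of cell type names into fused units (toy, greedy longest match)."""
    ordered = sorted(patterns, key=len, reverse=True)  # try longer patterns first
    fused, i = [], 0
    while i < len(layers):
        for pattern in ordered:
            if tuple(layers[i:i + len(pattern)]) == pattern:
                fused.append(pattern)
                i += len(pattern)
                break
        else:  # no pattern starts here: keep the cell on its own
            fused.append((layers[i],))
            i += 1
    return fused


print(fuse(["Conv2d", "BatchNorm2d", "ReLU", "Dense", "ReLU", "Dense"]))
# [('Conv2d', 'BatchNorm2d', 'ReLU'), ('Dense', 'ReLU'), ('Dense',)]
```

Because a bare Conv2d followed by BatchNorm2d is not among the five listed patterns, this toy leaves that pair unfused unless ("Conv2d", "BatchNorm2d") is appended to patterns.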

Parameters

network (Cell) – Network to be quantized.

Returns

Quantized network.

set_act_narrow_range(act_narrow_range)

Set the value of act_narrow_range in _config.

Parameters

act_narrow_range (bool) – Whether the quantization algorithm uses narrow range for activations. If True, activations are quantized with narrow range; otherwise, they are not.

Raises

TypeError – If act_narrow_range is not bool.

set_act_per_channel(act_per_channel)

Set the value of act_per_channel in _config.

Parameters

act_per_channel (bool) – Quantization granularity for activations. If True, quantization is per channel; otherwise, it is per layer. Only False is currently supported.

Raises

TypeError – If act_per_channel is not bool.

set_act_quant_delay(act_quant_delay)

Set the value of act_quant_delay in _config.

Parameters

act_quant_delay (int) – Number of steps after which activations are quantized during training and evaluation.

Raises
  • TypeError – If act_quant_delay is not int.

  • ValueError – If act_quant_delay is less than 0.

set_act_quant_dtype(act_quant_dtype)

Set the value of act_quant_dtype in _config.

Parameters

act_quant_dtype (QuantDtype) – Datatype used to quantize activations.

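Raises
  • TypeError – If act_quant_dtype is not QuantDtype.

  • NotImplementedError – If act_quant_dtype is not QuantDtype.INT8; only INT8 is currently supported.
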
set_act_symmetric(act_symmetric)

Set the value of act_symmetric in _config.

Parameters

act_symmetric (bool) – Whether the quantization algorithm uses symmetric quantization for activations. If True, symmetric quantization is used; otherwise, asymmetric quantization is used.

Raises

TypeError – If act_symmetric is not bool.

set_bn_fold(bn_fold)

Set the value of bn_fold in _config.

Parameters

bn_fold (bool) – Whether the quantization algorithm uses bn_fold.

Raises

TypeError – If bn_fold is not bool.

set_enable_fusion(enable_fusion)

Set the value of enable_fusion in _config.

Parameters

enable_fusion (bool) – Whether to apply fusion before applying quantization. Default: False.

Raises

TypeError – If enable_fusion is not bool.

set_freeze_bn(freeze_bn)

Set the value of freeze_bn in _config.

Parameters

freeze_bn (int) – Number of steps after which the BatchNorm parameters are fixed to the global mean and variance.

Raises
  • TypeError – If freeze_bn is not int.

  • ValueError – If freeze_bn is less than 0.

set_one_conv_fold(one_conv_fold)

Set the value of one_conv_fold in _config.

Parameters

one_conv_fold (bool) – Whether the quantization algorithm uses one_conv_fold.

Raises

TypeError – If one_conv_fold is not bool.

set_weight_narrow_range(weight_narrow_range)

Set the value of weight_narrow_range in _config.

Parameters

weight_narrow_range (bool) – Whether the quantization algorithm uses narrow range for weights. If True, weights are quantized with narrow range; otherwise, they are not.

Raises

TypeError – If weight_narrow_range is not bool.

set_weight_per_channel(weight_per_channel)

Set the value of weight_per_channel in _config.

Parameters

weight_per_channel (bool) – Quantization granularity for weights. If True, quantization is per channel; otherwise, it is per layer.

Raises

TypeError – If weight_per_channel is not bool.

set_weight_quant_delay(weight_quant_delay)

Set the value of weight_quant_delay in _config.

Parameters

weight_quant_delay (int) – Number of steps after which weights are quantized during training and evaluation.

Raises
  • TypeError – If weight_quant_delay is not int.

  • ValueError – If weight_quant_delay is less than 0.
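
quant_delay and freeze_bn are both step thresholds. The sketch below shows the schedule they describe, assuming a threshold counts as reached once the global step is greater than or equal to it (whether mindspore_gs compares with >= or > is not stated in this document). quant_active, bn_frozen and maybe_fake_quant are hypothetical helpers.

```python
def quant_active(global_step, quant_delay=(0, 0)):
    """Which tensors are fake-quantized at global_step.

    quant_delay is (data flow, weights), matching the class config convention.
    Assumption: a delay is reached once global_step >= delay.
    """
    act_delay, weight_delay = quant_delay
    return {"activations": global_step >= act_delay, "weights": global_step >= weight_delay}


def bn_frozen(global_step, freeze_bn=10000000):
    """Whether BatchNorm has switched to its global mean and variance (same assumption)."""
    return global_step >= freeze_bn


def maybe_fake_quant(value, global_step, delay, quantize):
    """Before the delay the value passes through unchanged; afterwards it is quantized."""
    return quantize(value) if global_step >= delay else value


# Mirrors the class example, where both delays were set to 900.
for step in (0, 899, 900):
    print(step, quant_active(step, (900, 900)))
```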

set_weight_quant_dtype(weight_quant_dtype)

Set the value of weight_quant_dtype in _config.

Parameters

weight_quant_dtype (QuantDtype) – Datatype used to quantize weights.

Raises
  • TypeError – If weight_quant_dtype is not QuantDtype.

  • NotImplementedError – If weight_quant_dtype is not QuantDtype.INT8; only INT8 is currently supported.

set_weight_symmetric(weight_symmetric)

Set the value of weight_symmetric in _config.

Parameters

weight_symmetric (bool) – Whether the quantization algorithm uses symmetric quantization for weights. If True, symmetric quantization is used; otherwise, asymmetric quantization is used.

Raises

TypeError – If weight_symmetric is not bool.

class mindspore_gs.SlbQuantAwareTraining(config=None)

Derived class of GoldenStick implementing the SLB (Searching for Low-Bit Weights) QAT algorithm.

Parameters

config (dict) –

Attributes for quantization aware training. Keys are attribute names and values are attribute values. Supported attributes are listed below:

  • quant_dtype (QuantDtype): Datatype used to quantize weights. Weight quantization currently supports INT4, INT2 and INT1. Default: QuantDtype.INT1.
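Our reading of the SLB paper ("Searching for Low-Bit Weights in Quantized Neural Networks", NeurIPS 2020) is that each weight is relaxed into a learnable distribution over a small set of discrete values whose size is set by the bit width. The plain-Python sketch below illustrates that idea for the supported dtypes. The evenly spaced candidate values in [-1, 1], the temperature-scaled softmax, the BITS mapping and every function name are assumptions, not the mindspore_gs implementation.

```python
import math

# Assumed mapping from the supported QuantDtype names to weight bit widths.
BITS = {"INT1": 1, "INT2": 2, "INT4": 4}


def candidate_levels(bits):
    """2**bits evenly spaced candidate weight values in [-1, 1] (assumption)."""
    n = 2 ** bits
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def softmax(logits, temperature):
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    m = max(logits)
    exps = [math.exp((x - m) / temperature) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_weight(logits, levels, temperature):
    """Training-time weight: expected candidate value under the softmax."""
    return sum(p * v for p, v in zip(softmax(logits, temperature), levels))


def hard_weight(logits, levels):
    """Inference-time weight: the candidate with the largest logit."""
    return levels[max(range(len(levels)), key=lambda i: logits[i])]


levels = candidate_levels(BITS["INT1"])
print(levels)  # [-1.0, 1.0]
logits = [0.1, 2.0]
print(hard_weight(logits, levels))  # 1.0
print(relaxed_weight(logits, levels, temperature=1.0))
```

Lowering the temperature pushes the relaxed weight toward the hard choice used at inference.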

Raises

TypeError – If quant_dtype is not QuantDtype.

Supported Platforms:

GPU

Examples

>>> from mindspore_gs.quantization.slb import SlbQuantAwareTraining
>>> from mindspore_gs.quantization.constant import QuantDtype
>>> from mindspore import nn
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SLB QAT-Algorithm
>>> slb_quantization = SlbQuantAwareTraining()
>>> ## 3) Use set functions to change config
>>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1)
>>> ## 4) Apply SLB QAT-algorithm to origin network
>>> net_qat = slb_quantization.apply(net)
>>> ## 5) Print network and check the result. Conv2d should be transformed to a QuantizeWrapperCell.
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the bit_num value of fake_quant_weight
>>> ## should be 1, and the weight_bit_num value of Conv2dSlbQuant should be 1.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
  (Conv2dSlbQuant): QuantizeWrapperCell<
    (_handler): Conv2dSlbQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1>
      >
    >
  >
apply(network: Cell)

Apply the SLB quantization algorithm to network. The following steps prepare the network for quantization aware training:

  1. Fuse certain cells in the network using the pattern engine defined by the net policy.

  2. Propagate the layer policies through the cells.

  3. Reduce redundant fake quantizers, that is, two or more fake quantizers on the same tensor.

  4. Apply the layer policies to convert normal cells into QuantizeWrapperCells.
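Step 3 removes duplicate quantizers on a shared tensor, for example the output quantizer of one cell and the input quantizer of the next. The toy sketch below works on (tensor, quantizer) pairs and assumes that the first quantizer seen for a tensor is the one kept; reduce_redundant is a hypothetical name, and the real policy may choose a different survivor.

```python
def reduce_redundant(quantizers):
    """Keep one fake quantizer per tensor.

    quantizers: list of (tensor_name, quantizer_name) pairs in network order.
    Assumption: the first quantizer seen on a tensor is kept; later ones are dropped.
    """
    seen = set()
    kept, removed = [], []
    for tensor, quantizer in quantizers:
        if tensor in seen:
            removed.append(quantizer)
        else:
            seen.add(tensor)
            kept.append((tensor, quantizer))
    return kept, removed


pairs = [
    ("x", "conv.input_quantizer"),
    ("conv_out", "conv.output_quantizer"),
    ("conv_out", "relu.input_quantizer"),  # same tensor as the line above: redundant
]
kept, removed = reduce_redundant(pairs)
print(removed)  # ['relu.input_quantizer']
```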

Parameters

network (Cell) – Network to be quantized.

Returns

Quantized network.

set_weight_quant_dtype(weight_quant_dtype)

Set the value of weight_quant_dtype in _config.

Parameters

weight_quant_dtype (QuantDtype) – Datatype used to quantize weights. Default: QuantDtype.INT1.

Raises
  • TypeError – If weight_quant_dtype is not QuantDtype.

  • NotImplementedError – If weight_quant_dtype is not QuantDtype.INT1, QuantDtype.INT2 or QuantDtype.INT4; only these are currently supported.