mindspore_gs.quantization.SimulatedQuantizationAwareTraining

class mindspore_gs.quantization.SimulatedQuantizationAwareTraining(config=None)[源代码]

模拟量化感知训练的基本实现,该算法在训练时使用伪量化节点来模拟量化计算的损失,并通过反向传播更新网络参数,使得网络参数更好地适应量化带来的损失。更多详细信息见 神经网络量化白皮书

参数:
  • config (dict) - 存储用于量化感知训练的属性,键是属性名称,值是属性值。默认值为None,下面列出了支持的属性:

    • quant_delay (Union[int, list, tuple]) - 在训练和评估期间权重和激活量化后的步数。第一个元素表示激活,第二个元素表示权重。默认值:(0, 0)。

    • quant_dtype (Union[QuantDtype, list, tuple]) - 用于指定量化的目标数据类型。在设置`quant_dtype`时,必须考虑硬件设备的精度支持。第一个元素表示激活,第二个元素表示权重。默认值:(QuantDtype.INT8, QuantDtype.INT8)。

    • per_channel (Union[bool, list, tuple]) - 基于层或通道的量化粒度。如果为True,则基于每个通道,否则基于每个层。第一个元素表示激活,第二个元素表示权重,第一个元素现在必须为 False。默认值:(False, False)。

    • symmetric (Union[bool, list, tuple]) - 量化算法是否对称。如果为True,则基于对称,否则基于不对称。第一个元素表示激活,第二个元素表示权重。默认值:(False, False)。

    • narrow_range (Union[bool, list, tuple]) - 量化算法是否使用窄范围。第一个元素表示激活,第二个元素表示权重。默认值:(False, False)。

    • enable_fusion (bool) - 在应用量化之前是否应用融合。默认值:False。

    • freeze_bn (int) - BatchNorm OP 参数固定为全局均值和方差之后的步数。默认值:10000000。

    • bn_fold (bool) - 是否使用 bn fold 算子进行模拟推理操作。默认值:False。

    • one_conv_fold (bool) - 是否使用 one conv bn fold 算子进行模拟推理操作。默认值:True。

异常:
  • TypeError - bn_foldone_conv_fold 或者 enable_fusion 的元素类型不是bool。

  • TypeError - freeze_bn 的数据类型不是int。

  • TypeError - quant_delay 的数据类型不是int,或者 quant_delay 存在不是int的元素。

  • TypeError - quant_dtype 的数据类型不是 QuantDtype ,或者 quant_dtype 存在不是 QuantDtype 的元素。

  • TypeError - per_channel 的数据类型不是bool,或者 per_channel 存在不是bool的元素。

  • TypeError - symmetric 的数据类型不是bool,或者 symmetric 存在不是bool的元素。

  • TypeError - narrow_range 的数据类型不是bool,或者 narrow_range 存在不是bool的元素。

  • ValueError - freeze_bn 小于0。

  • ValueError - quant_delayquant_dtypeper_channelsymmetric 或者 narrow_range 的长度大于2。

  • ValueError - quant_delay 小于0,或者 quant_delay 存在小于0的元素。

  • ValueError - quant_dtype 的数据类型不是 QuantDtype.INT8 或者 quant_dtype 存在不是 QuantDtype.INT8 的元素。

  • ValueError - per_channel 为True, 或者 per_channel 的第一个元素为True。

支持平台:

GPU

样例:

>>> from mindspore_gs.quantization import SimulatedQuantizationAwareTraining
>>> from mindspore import nn
... class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SimQAT Algorithm
>>> simulated_quantization = SimulatedQuantizationAwareTraining()
>>> ## 3) Use set functions to change config
>>> simulated_quantization.set_enable_fusion(True)
>>> simulated_quantization.set_bn_fold(False)
>>> simulated_quantization.set_act_quant_delay(900)
>>> simulated_quantization.set_weight_quant_delay(900)
>>> simulated_quantization.set_act_per_channel(False)
>>> simulated_quantization.set_weight_per_channel(True)
>>> simulated_quantization.set_act_narrow_range(False)
>>> simulated_quantization.set_weight_narrow_range(False)
>>> ## 4) Apply SimQAT algorithm to origin network
>>> net_qat = simulated_quantization.apply(net)
>>> ## 5) Print network and check the result. Conv2d and Dense should be transformed to QuantizeWrapperCells.
>>> ## Since we set enable_fusion to be True, bn_fold to be False, the Conv2d and BatchNorm2d Cells are
>>> ## fused and converted to Conv2dBnWithoutFoldQuant.
>>> ## Since we set act_quant_delay to be 900, the quant_delay value of _input_quantizer and _output_quantizer
>>> ## are set to be 900.
>>> ## Since we set weight_quant_delay to be 900, the quant_delay value of fake_quant_weight are set to be 900.
>>> ## Since we set act_per_channel to be False, the per_channel value of _input_quantizer and
>>> ## _output_quantizer are set to be False.
>>> ## Since we set weight_per_channel to be True, the per_channel value of fake_quant_weight are set to be
>>> ## True.
>>> ## Since we set act_narrow_range to be False, the narrow_range value of _input_quantizer and
>>> ## _output_quantizer are set to be False.
>>> ## Since we set weight_narrow_range to be False, the narrow_range value of fake_quant_weight are set to be
>>> ## True.
>>> print(net_qat)
NetToQuantOpt<
(_handler): NetToQuant<
(conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=_handler.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=_handler.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=_handler.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=_handler.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
>
(Conv2dBnWithoutFoldQuant): QuantizeWrapperCell<
handler: in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, input quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900, output quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900
(_handler): Conv2dBnWithoutFoldQuant<
in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
(fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 6), quant_delay=900>
(batchnorm): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.0030000000000000027, gamma=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
>
(_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
(_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
>
>
apply(network: Cell)[源代码]

按照以下步骤在 network 中应用SimQAT算法,使 network 可用于量化感知训练:

  1. 使用由网络策略定义的模式引擎融合 network 中的某些单元。默认融合模式:Conv2d + BatchNorm2d + ReLU, Conv2d + ReLU, Dense + BatchNorm2d + ReLU, Dense + BatchNorm2d, Dense + ReLU。

  2. 在网络中传播NetPolicy中定义的LayerPolices。

  3. 减少冗余的假量化器,即一个张量上存在两个或多个假量化器。

  4. 应用LayerPolicies将普通Cell转换为 QuantizeWrapperCell 。在此步骤中,我们将在网络中插入真正的假量化器。

参数:
  • network (Cell) - 待量化的网络。

返回:

量化后的网络。

convert(net_opt: Cell, ckpt_path='')[源代码]

将量化网络 net_opt 转换为标准网络,后续导出成MindIR用于部署。

参数:
  • net_opt (Cell) - 经过量化算法apply之后的网络。

  • ckpt_path (str) - 网络的checkpoint file文件路径,默认值为空,表示不加载。

异常:
  • TypeError - net_opt 数据类型不是Cell。

  • TypeError - ckpt_path 数据类型不是str。

  • ValueError - ckpt_path 非空但不是有效路径。

返回:

转换后的网络。

set_act_narrow_range(act_narrow_range)[源代码]

设置量化感知训练参数 config 的act_narrow_range值。

参数:
  • act_narrow_range (bool) - 量化算法是否使用 act_narrow_range 。如果为True,则基于narrow_range,否则不基于narrow_range。

异常:
  • TypeError - act_narrow_range 数据类型不是bool。

set_act_per_channel(act_per_channel)[源代码]

设置量化感知训练参数 config 的act_per_channel值。

参数:
  • act_per_channel (bool) - 量化算法基于层还是通道。如果为True,则基于通道,否则基于层。当前只支持False。

异常:
  • TypeError - act_per_channel 数据类型不是bool。

  • ValueError - act_per_channel 不是False。

set_act_quant_delay(act_quant_delay)[源代码]

设置量化感知训练参数 config 的act_quant_delay值。

参数:
  • act_quant_delay (int) - 在训练和评估期间激活量化后的步数。

异常:
  • TypeError - act_quant_delay 数据类型不是int。

  • ValueError - act_quant_delay 小于0。

set_act_quant_dtype(act_quant_dtype)[源代码]

设置量化感知训练参数 config 的act_quant_dtype值。

参数:
  • act_quant_dtype (QuantDtype) - 激活量化的数据类型。

异常:
  • TypeError - act_quant_dtype 数据类型不是QuantDtype。

  • ValueError - act_quant_dtype 不是 QuantDtype.INT8

set_act_symmetric(act_symmetric)[源代码]

设置量化感知训练参数 config 的act_symmetric值。

参数:
  • act_symmetric (bool) - 量化算法是否使用激活对称。如果为True,则基于对称,否则基于不对称。

异常:
  • TypeError - act_symmetric 数据类型不是bool。

set_bn_fold(bn_fold)[源代码]

设置量化感知训练参数 config 的bn_fold值。

参数:
  • bn_fold (bool) - 量化算法是否使用 bn_fold

异常:
  • TypeError - bn_fold 数据类型不是bool。

set_enable_fusion(enable_fusion)[源代码]

设置量化感知训练参数 config 的enable_fusion值。

参数:
  • enable_fusion (bool) - 是否在量化之前进行融合。

异常:
  • TypeError - enable_fusion 数据类型不是bool。

set_freeze_bn(freeze_bn)[源代码]

设置量化感知训练参数 config 的freeze_bn值。

参数:
  • freeze_bn (int) - BatchNorm OP 参数固定为全局均值和方差之后的步数。

异常:
  • TypeError - freeze_bn 数据类型不是int。

  • ValueError - freeze_bn 小于0。

set_one_conv_fold(one_conv_fold)[源代码]

设置量化感知训练参数 config 的one_conv_fold值。

参数:
  • one_conv_fold (bool) - 量化算法是否使用 one_conv_fold

异常:
  • TypeError - one_conv_fold 数据类型不是bool。

set_weight_narrow_range(weight_narrow_range)[源代码]

设置量化感知训练参数 config 的weight_narrow_range值。

参数:
  • weight_narrow_range (bool) - 量化算法是否使用权重narrow_range。如果为True,则基于narrow_range,否则不基于narrow_range。

异常:
  • TypeError - weight_narrow_range 数据类型不是bool。

set_weight_per_channel(weight_per_channel)[源代码]

设置量化感知训练参数 config 的weight_per_channel值。

参数:
  • weight_per_channel (bool) - 量化算法基于层还是通道。如果为True,则基于通道,否则基于层。

异常:
  • TypeError - weight_per_channel 数据类型不是bool。

set_weight_quant_delay(weight_quant_delay)[源代码]

设置量化感知训练参数 config 的weight_quant_delay值。

参数:
  • weight_quant_delay (int) - 在训练和评估期间权重量化后的步数。

异常:
  • TypeError - weight_quant_delay 数据类型不是int。

  • ValueError - weight_quant_delay 小于0。

set_weight_quant_dtype(weight_quant_dtype)[源代码]

设置量化感知训练参数 config 的weight_quant_dtype值。

参数:
  • weight_quant_dtype (QuantDtype) - 权重量化数据类型。

异常:
  • TypeError - weight_quant_dtype 数据类型不是QuantDtype。

  • ValueError - weight_quant_dtype 不是 QuantDtype.INT8

set_weight_symmetric(weight_symmetric)[源代码]

设置量化感知训练参数 config 的weight_symmetric值。

参数:
  • weight_symmetric (bool) - 量化算法是否使用权重对称。如果为True,则基于对称,否则基于不对称。

异常:
  • TypeError - weight_symmetric 数据类型不是bool。