Source code for mindspore.compression.quant.quant_utils

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization utils."""

import numpy as np


__all__ = ["load_nonquant_param_into_quant_net"]


def cal_quantization_params(input_min,
                            input_max,
                            data_type,
                            num_bits=8,
                            symmetric=False,
                            narrow_range=False):
    r"""
    Calculate quantization params for scale and zero point.

    The real range is first widened so that it always contains 0.0, then mapped
    onto the integer range of `data_type`:

    .. math::
        scale = (max - min) / (quant\_max - quant\_min)

        zero\_point = floor(quant\_min - min / scale + 0.5)

    For signed (int8) symmetric quantization the zero point is fixed at 0. For
    unsigned symmetric quantization it is computed with the formula above, so that
    it lands in the middle of the integer range. Channels whose widened range is
    empty (min == max == 0) get scale 1 and zero point 0. This avoids a zero scale
    and a 0/0 zero point.

    Args:
        input_min (numpy.ndarray): The dimension of channel or 1.
        input_max (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type) : Can be numpy int8, numpy uint8.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.

    Raises:
        ValueError: If `input_min` and `input_max` shapes differ, are more than
            one-dimensional, or `data_type` is neither int8 nor uint8.
    """
    # Widen the range so that real 0.0 is always representable.
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

    if input_min.shape != input_max.shape:
        raise ValueError("input min shape should equal to input max.")
    if len(input_min.shape) > 1:
        raise ValueError("input min and max shape should be one dim.")
    # Sanity check only: after the widening above min <= 0 <= max always holds.
    if (input_min > input_max).all():
        raise ValueError("input_min min should less than input max.")
    if (input_max == input_min).all():
        return np.ones(input_min.shape), np.zeros(input_min.shape)

    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1

    # calculate scale
    if symmetric:
        input_max = np.maximum(-input_min, input_max)
        input_min = -input_max
    scale = (input_max - input_min) / (quant_max - quant_min)

    # Some channels may still have an empty range even though not all of them do.
    # Give those channels the same (1, 0) params as the all-empty case above,
    # instead of dividing by a zero scale below.
    empty_range = scale == 0
    scale = np.where(empty_range, 1.0, scale)

    # calculate zero point
    if data_type == np.int8 and symmetric:
        # A signed symmetric integer range is centred on 0.
        zp = np.zeros(input_min.shape)
    else:
        # Unsigned symmetric ranges still need an offset to the middle of
        # [quant_min, quant_max]; the general formula provides it.
        zp_double = quant_min - input_min / scale
        zp = np.floor(zp_double + 0.5)
    zp = np.where(empty_range, 0.0, zp)

    return scale, zp


def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
    r"""
    Calculate int8/uint8 weight from fp32. the formula is defined as:

    .. math::
        int8/uint8 = clip(round(float / scale + offset), quant\_min, quant\_max)

    Args:
        data (numpy.ndarray): Weight to quantize. The channel axis is 0 for
            `Conv2d`/`Dense` weights and 1 for `DepthwiseConv2d` weights.
        scale (numpy.ndarray): The dimension of channel or 1 (a 0-d array is
            also accepted as a per-layer scale).
        zero_point (numpy.ndarray): Same shape as `scale`.
        data_type (numpy type) : Can be numpy int8, numpy uint8.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

    Returns:
        weight (numpy.ndarray): Quantized values (still a float array) with the
            shape of `data`.

    Raises:
        ValueError: If `scale`/`zero_point` shapes differ or are empty, if a
            per-channel scale matches neither axis 0 nor axis 1 of `data`, or if
            `data_type` is unsupported.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    # A 0-d scale is a valid per-layer param; only an empty 1-D array is invalid.
    if len(scale.shape) >= 1 and scale.shape[0] == 0:
        raise ValueError("`scale` and `zero_point` shape should greater than zero.")
    if len(scale.shape) >= 1 and scale.shape[0] > 1:
        # for perchannel: reshape so that scale broadcasts along the channel axis
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * (len(data.shape) - 1)
        elif len(data.shape) > 1 and scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * (len(data.shape) - 2)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))
        scale = scale.reshape(shape_list)
        zero_point = zero_point.reshape(shape_list)

    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported weight datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1

    weight_int = np.round((data / scale) + zero_point)
    return np.clip(weight_int, quant_min, quant_max)

def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
    """
    Compute scale, zero point, max and min from a `FakeQuantWithMinMax` cell.

    Args:
        cell (Cell): Cell exposing `minq`/`maxq` parameters and a
            `fake_quant_infer` primitive that carries `num_bits`, `symmetric`
            and `narrow_range`.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        tuple: `(scale, zero_point, maxq, minq)`. Note that max comes before min.
    """
    min_values = cell.minq.data.asnumpy()
    max_values = cell.maxq.data.asnumpy()
    fake_quant_op = cell.fake_quant_infer

    quant_options = {
        "num_bits": fake_quant_op.num_bits,
        "symmetric": fake_quant_op.symmetric,
        "narrow_range": fake_quant_op.narrow_range,
    }
    scale, zero_point = cal_quantization_params(min_values, max_values, data_type, **quant_options)
    return scale, zero_point, max_values, min_values


def scale_zp_from_data(op, minq, maxq, data_type):
    r"""
    Compute scale and zero point from min/max parameters and a fake quant primitive.

    The quantization settings (`num_bits`, `symmetric`, `narrow_range`) are
    taken from `op`; the range comes from `minq` and `maxq`.

    Args:
        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
            `mindspore.ops.operation.FakeQuantPerChannel`
        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    min_array = minq.data.asnumpy()
    max_array = maxq.data.asnumpy()

    scale, zero_point = cal_quantization_params(min_array,
                                                max_array,
                                                data_type,
                                                num_bits=op.num_bits,
                                                symmetric=op.symmetric,
                                                narrow_range=op.narrow_range)
    return scale, zero_point


def scale_zp_max_min_from_data(op, minq, maxq, data_type):
    """
    Compute scale, zero point, max and min from min/max parameters and a fake quant primitive.

    Args:
        op (Primitive): Fake quant primitive providing `num_bits`, `symmetric`
            and `narrow_range`.
        minq (Parameter): Parameter holding the range minimum.
        maxq (Parameter): Parameter holding the range maximum.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        tuple: `(scale, zero_point, maxq, minq)` as numpy arrays. Note that max
        comes before min.
    """
    min_array = minq.data.asnumpy()
    max_array = maxq.data.asnumpy()

    params = cal_quantization_params(min_array, max_array, data_type,
                                     num_bits=op.num_bits,
                                     symmetric=op.symmetric,
                                     narrow_range=op.narrow_range)
    scale, zero_point = params
    return scale, zero_point, max_array, min_array


def _fold_bn_params(weight, gamma, beta, mean, variance, epsilon):
    """Fold batchnorm statistics into `weight`; return `(folded_weight, folded_bias)`."""
    sigma = np.sqrt(variance + epsilon)

    if gamma.shape[0] == weight.shape[0]:
        # `Conv2d` or `Dense` op weight: channels on axis 0
        shape_list = [-1] + [1] * len(weight.shape[1:])
    elif len(weight.shape) > 1 and gamma.shape[0] == weight.shape[1]:
        # `DepthwiseConv2d` op weight: channels on axis 1
        shape_list = [1, -1] + [1] * len(weight.shape[2:])
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    folded_weight = weight * gamma.reshape(shape_list) / sigma.reshape(shape_list)
    folded_bias = beta - gamma * mean / sigma
    return folded_weight, folded_bias


def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnFoldQuant` to weight.

    Uses the batchnorm parameters stored directly on `cell_quant`
    (`gamma`, `beta`, `moving_mean`, `moving_variance`, `eps`):

    .. math::
        weight' = weight * gamma / \sqrt{variance + eps}

        bias' = beta - gamma * mean / \sqrt{variance + eps}

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.

    Raises:
        ValueError: If the channel count matches neither axis 0 nor axis 1 of `weight`.
    """
    return _fold_bn_params(weight,
                           cell_quant.gamma.data.asnumpy(),
                           cell_quant.beta.data.asnumpy(),
                           cell_quant.moving_mean.data.asnumpy(),
                           cell_quant.moving_variance.data.asnumpy(),
                           cell_quant.eps)


def without_fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.

    Same folding as `fold_batchnorm`, but the parameters are read from the
    nested `cell_quant.batchnorm` sub-cell:

    .. math::
        weight' = weight * gamma / \sqrt{variance + eps}

        bias' = beta - gamma * mean / \sqrt{variance + eps}

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.

    Raises:
        ValueError: If the channel count matches neither axis 0 nor axis 1 of `weight`.
    """
    batchnorm = cell_quant.batchnorm
    return _fold_bn_params(weight,
                           batchnorm.gamma.data.asnumpy(),
                           batchnorm.beta.data.asnumpy(),
                           batchnorm.moving_mean.data.asnumpy(),
                           batchnorm.moving_variance.data.asnumpy(),
                           batchnorm.eps)


def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
    r"""
    Load fp32 model parameters into quantization model.

    Parameters are matched by order within each parameter kind, where the kind is
    the last dotted component of the name (`weight`, `bias`, `gamma`, `beta`,
    `moving_mean`, `moving_variance`, `minq` or `maxq`). The n-th quantization-net
    parameter of a kind receives the n-th entry of `params_dict`, in iteration
    order, whose name ends with that kind. Once the checkpoint entries of a kind
    run out, the remaining quantization-net parameters of that kind are left
    unchanged.

    Args:
        quant_model: quantization model.
        params_dict: parameter dict that stores fp32 parameters.
        quant_new_params: names (last dotted component) of parameters that exist in
            the quantization network but not in the non-quantized network; they are
            skipped. Default: None.

    Returns:
        None

    Raises:
        ValueError: If a quantization-net parameter has an unsupported kind and is
            not listed in `quant_new_params`.
    """
    supported_kinds = ('weight', 'bias', 'gamma', 'beta',
                       'moving_mean', 'moving_variance', 'minq', 'maxq')

    # Group checkpoint entries by name suffix in a single pass over params_dict.
    grouped = {kind: [] for kind in supported_kinds}
    for item in params_dict.items():
        for kind in supported_kinds:
            if item[0].endswith(kind):
                grouped[kind].append(item)
    iterable_dict = {kind: iter(items) for kind, items in grouped.items()}

    for name, param in quant_model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict:
            if quant_new_params is not None and key_name in quant_new_params:
                continue
            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')