# Source code for mindspore_gs.ptq.ptq.quant

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PTQ algorithm."""
from functools import partial
from typing import List, Union, Tuple, Optional
from collections import OrderedDict
import time
import gc
import os
import copy
import tqdm
from mindspore import dtype, get_context, PYNATIVE_MODE
from mindspore.nn import Cell
from mindspore.nn.utils import no_init_parameters
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore_gs.comp_algo import CompAlgo
from mindspore_gs.common import logger
from mindspore_gs.common.utils import offload_network, value_check
from mindspore_gs.ptq.processor import transform_network_inplace
from mindspore_gs.ptq.ptq_config import PTQConfig, PTQMode, OutliersSuppressionType, PrecisionRecovery
from mindspore_gs.ptq.context import InnerPTQConfig, PTQApproach
from mindspore_gs.ptq.network_helpers import NetworkHelper
from mindspore_gs.ptq.ptq.wrapper_cell import WrapperCell, SearchInputs
from mindspore_gs.ptq.processor import Processor
from .algorithm import Algorithm
from .algorithms import LinearSmoothQuant, LinearAutoSmoother, LinearClipper, Quantizer


class InputCatcher(Cell):
    """Intercept and record the inputs of one cell's forward call.

    ``patch`` swaps the target cell's ``construct`` for ``InputCatcher.construct`` bound to this
    catcher. Every intercepted call appends its positional args (as a list) to ``self.args`` and
    its keyword args to ``self.kwargs``, then raises ``GeneratorExit`` to abort the surrounding
    forward pass. ``recover`` reinstates the saved ``construct``.
    """

    def __init__(self):
        super().__init__()
        # Cell whose construct is (or was last) redirected to this catcher.
        self.handler = None
        # One entry per intercepted call: the positional arguments as a list.
        self.args = []
        # One entry per intercepted call: the keyword-argument dict.
        self.kwargs = []
        # The handler's construct captured at patch time, restored by recover().
        self.old_construct = None
        # True while a cell is patched; only one cell may be patched at a time.
        self.patched = False

    def patch(self, handler):
        """Redirect ``handler.construct`` to this catcher.

        Args:
            handler (Cell): The cell whose inputs should be captured.

        Raises:
            RuntimeError: If a cell is already patched and ``recover`` has not been called.
        """
        if self.patched:
            raise RuntimeError("Only support patch one cell for one time. please invoke recover before invoking patch "
                               "again.")
        self.handler = handler
        self.old_construct = handler.construct
        interceptor = partial(InputCatcher.construct, self)
        self.handler.construct = interceptor
        self.patched = True

    def recover(self):
        """Restore the patched cell's original ``construct`` (when one was saved) and clear the patched flag."""
        can_restore = self.patched and self.handler and self.old_construct
        if can_restore:
            self.handler.construct = self.old_construct
        self.patched = False

    def construct(self, *args, **kwargs):
        """Store this call's inputs, then stop the forward pass by raising ``GeneratorExit``."""
        self.args.append([*args])
        self.kwargs.append(kwargs)
        # GeneratorExit is used to unwind out of the caller's generate loop once inputs are captured.
        raise GeneratorExit("already catch first layer inputs, do not need continue.")


class PTQ(CompAlgo):
    """
    Implementation of PTQ algorithm which supports the combination quantization of activation, weight, and kvcache.

    Args:
        config(:class:`mindspore_gs.ptq.PTQConfig`, optional): config for PTQ, default is ``None``.
        layer_policies(OrderedDict, optional): quantization strategy for layers, default is ``None``.
            The key of `layer_policies` is regular string to match the layer name,
            the value of `layer_policies` is :class:`mindspore_gs.ptq.PTQConfig`.

    Raises:
        TypeError: If `config` type is not PTQConfig when it's not ``None``.
        TypeError: If any value in `layer_policies` type is not PTQConfig when it's not ``None``.
        ValueError: If not PYNATIVE mode when mode in config is PTQMode.QUANTIZE (raised by ``apply``).
        ValueError: If act_quant_dtype is int8 and weight_quant_dtype is None.
        TypeError: If layer_policies is not an OrderedDict.

    Examples:
        >>> import mindspore_gs
        >>> from collections import OrderedDict
        >>> from mindspore import set_context, PYNATIVE_MODE
        >>> from mindspore_gs.ptq import PTQ
        >>> from mindspore_gs.ptq import PTQConfig
        >>> from mindspore_gs.ptq.ptq_config import PTQMode, OutliersSuppressionType
        >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
        >>> from mindformers.tools.register.config import MindFormerConfig
        >>> from mindformers import LlamaForCausalLM, LlamaConfig
        >>> from mindspore_gs.common.gs_enum import BackendTarget
        >>> from mindspore import dtype as msdtype
        >>> set_context(mode=PYNATIVE_MODE)
        >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
        >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
        >>> helper = MFLlama2Helper(mfconfig)
        >>> backend = BackendTarget.ASCEND
        >>> ptq_config = PTQConfig(mode=PTQMode.QUANTIZE, backend=backend, opname_blacklist=["w2", "lm_head"],
        ...                        weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8,
        ...                        outliers_suppression=OutliersSuppressionType.SMOOTH)
        >>> attn_policy = PTQConfig(mode=PTQMode.QUANTIZE, backend=backend,
        ...                         weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8,
        ...                         outliers_suppression=OutliersSuppressionType.NONE)
        >>> layer_policy = OrderedDict({r'.*Attention.wo.*': attn_policy})
        >>> ptq = PTQ(ptq_config, layer_policy)
        >>> network = LlamaForCausalLM(LlamaConfig(**mfconfig.model.model_config))
        >>> ds = ...  # calibration dataset whose items provide an 'input_ids' column
        >>> fake_quant_net = ptq.apply(network, helper, datasets=ds)
        >>> quant_net = ptq.convert(fake_quant_net)
        >>> ptq.summary(quant_net)
    """

    def __init__(self, config: Union[dict, PTQConfig] = None, layer_policies=None):
        super().__init__()
        if config is not None:
            if not isinstance(config, PTQConfig):
                raise TypeError(f'Shall init PTQ with PTQConfig, but got {type(config)}')
            self._config = config
        else:
            self._config = PTQConfig()
        if layer_policies is None:
            self.layer_policies = OrderedDict()
        else:
            self.layer_policies = layer_policies
        # convert PTQConfig to InnerConfig to add inner parameters
        self._config = InnerPTQConfig().inner_config(self._config, approach=PTQApproach.PTQ)
        self._generate_func = None
        logger.info(f"Config for PTQ: {self._config}")
        PTQ._ptq_config_check(self._config)
        self._layer_policies_check()
        self.pipeline: List[Algorithm] = []
        # (cell_name, cell) pairs filled in by _get_decoder_layers.
        self.decoder_layers: List[Tuple[str, Cell]] = []
        self.decoder_layer_types: list = []
        self.context_mode = get_context("mode")
        self._target_layer_type = ()
        self._build_pipeline()
        self._load_mindformers_plugin()

    def _append_algorithm(self, name, algorithm: Algorithm):
        """Append `algorithm` to the processing pipeline, logging `name`."""
        logger.info(f"append {name} to pipeline.")
        self.pipeline.append(algorithm)

    def _build_pipeline(self):
        """Build the pipeline; processors run in the order SmoothQuant, AutoSmoother, Clipper, Quantizer."""
        smoothquant = LinearSmoothQuant(self._config, self.layer_policies)
        clipper = LinearClipper(self._config, self.layer_policies)
        awq = LinearAutoSmoother(self._config, self.layer_policies)
        quantizer = Quantizer(self._config, self.layer_policies)
        self._append_algorithm('LinearSmoothQuant', smoothquant)
        self._append_algorithm('LinearAutoSmoother', awq)
        self._append_algorithm('LinearClipper', clipper)
        self._append_algorithm('Quantizer', quantizer)

    def _load_mindformers_plugin(self):
        """Load mindformers plugins of every algorithm, register known decoder layer types and the generate func.

        Each optional decoder layer type is imported in its own ``try`` so that one missing module does not
        prevent the others from being registered.
        """
        for algorithm in self.pipeline:
            algorithm.load_mindformers_plugin()
            self._target_layer_type += algorithm.target_layer_type()
        from mindformers.models.llama.llama_transformer import LLamaDecodeLayer
        self.decoder_layer_types.append(LLamaDecodeLayer)
        try:
            from mindformers.experimental.infer.core.transformer import ParallelTransformerLayer
            self.decoder_layer_types.append(ParallelTransformerLayer)
        except ImportError:
            pass
        try:
            from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer
            self.decoder_layer_types.append(DeepseekV3DecodeLayer)
        except ImportError:
            pass
        try:
            from research.llama3_1.infer.transformer import ParallelTransformerLayer as LlamaParallelTransformerLayer
            self.decoder_layer_types.append(LlamaParallelTransformerLayer)
        except ImportError:
            pass
        try:
            from research.telechat2.infer.telechat_transformers import TelechatParallelTransformerLayer
            self.decoder_layer_types.append(TelechatParallelTransformerLayer)
        except ImportError:
            pass

        def generate(network, input_ids, helper=None):
            # Only one new token is requested: the forward pass is used just to reach the first decoder layer.
            if isinstance(helper, NetworkHelper):
                return helper.generate(network, input_ids, do_sample=False, max_new_tokens=1)
            return network.generate(input_ids, do_sample=False, max_new_tokens=1)

        self._generate_func = generate

    def _get_decoder_layers(self, network: Cell):
        """
        Collect decoder layers of `network` into ``self.decoder_layers``.

        Args:
            network (nn.Cell): Network to get decoder layers.

        Returns:
            None. ``self.decoder_layers`` is set to a list of ``(cell_name, Cell)`` tuples. When no cell matches
            ``self.decoder_layer_types``, it falls back to ``[("network", network)]`` and a warning is logged.

        Raises:
            TypeError: (presumably, via ``value_check``) if `network` is not a Cell.
        """
        value_check('network', network, Cell)

        class NetworkWalker(Processor):
            """Record every cell that is an instance of one of the given decoder layer types."""

            def __init__(self, decoder_layer_types_):
                self.layers = []
                self._decoder_layer_types = decoder_layer_types_

            def process_cell(self, cell_name: str, cell: Cell) -> Tuple[Cell, bool]:
                # NOTE(review): the returned bool presumably tells Processor not to descend into matched cells.
                if isinstance(cell, self._decoder_layer_types):
                    self.layers.append((cell_name, cell))
                    return cell, True
                return cell, False

        walker = NetworkWalker(tuple(self.decoder_layer_types))
        walker.process(network)
        if walker.layers:
            self.decoder_layers = walker.layers
            return
        self.decoder_layers = [("network", network)]
        logger.warning(
            f"No decoder layer found in network. Visible decoder layer types: {self.decoder_layer_types}, "
            "please modify PTQ.decoder_layer_types before invoking apply method. If not, PTQ will take lots of memory.")

    @staticmethod
    def _ptq_config_check(config):
        """Validate combinations of quantization options in `config`.

        Raises:
            ValueError: If activation is quantized without weight, if AWQ/GPTQ is combined with W8A8,
                if OUTLIER_SUPPRESSION_LITE is used for anything other than plain W8A8, or if int4 weight
                is combined with int8 kvcache.
        """
        use_w8 = config.weight_quant_dtype == dtype.int8
        use_a8 = config.act_quant_dtype == dtype.int8
        no_outliers_suppression = config.outliers_suppression in (None, OutliersSuppressionType.NONE)
        if no_outliers_suppression and use_a8 and use_w8:
            logger.warning("When outliers_suppression is None, A8W8 algorithm accuracy is expected to decline.")
        if config.weight_quant_dtype is None and use_a8:
            raise ValueError("PTQ algorithm do not support only quant activation.")
        use_gptq_or_awq = (config.outliers_suppression == OutliersSuppressionType.AWQ or
                           config.precision_recovery == PrecisionRecovery.GPTQ)
        if use_w8 and use_a8 and use_gptq_or_awq:
            raise ValueError("AWQ algorithm and GPTQ algorithm do not support quant activation.")
        use_a8w8_only = use_a8 and use_w8 and config.kvcache_quant_dtype is None
        use_osl = config.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_LITE
        if not use_a8w8_only and use_osl:
            raise ValueError("OUTLIER_SUPPRESSION_LITE algorithm only support W8A8 quant.")
        use_w4 = config.weight_quant_dtype == dtype.qint4x2
        use_c8 = config.kvcache_quant_dtype == dtype.int8
        if use_w4 and use_c8:
            raise ValueError("PTQ algorithm only support quant weight in int4 alone. "
                             "Please not to use with c8 at the same time.")

    def _layer_policies_check(self):
        """Validate ``self.layer_policies`` and convert every non-empty policy to an inner config in place.

        The mode/backend of a layer policy are overwritten with the network-level values (with a warning).

        Raises:
            TypeError: If layer_policies is not an OrderedDict, a key is not a string or not a valid regex,
                or a non-empty value is not a PTQConfig.
        """
        import re
        if not isinstance(self.layer_policies, OrderedDict):
            raise TypeError(f'layer_policies should be an OrderedDict, but got {type(self.layer_policies)}.')
        if any(not isinstance(key, str) for key in self.layer_policies.keys()):
            raise TypeError('all key of layer_policies should be a string.')
        try:
            # Only values are reassigned below (no keys added/removed), so mutating while iterating is safe.
            for key, config_ in self.layer_policies.items():
                if not config_:
                    continue
                re.compile(key)
                if not isinstance(config_, PTQConfig):
                    raise TypeError(f'The type of value in layer_policies should be PTQConfig, '
                                    f'but got {type(config_)}')
                if config_.mode != self._config.mode:
                    logger.warning(f'The mode={config_.mode} in {key} layer policy different from '
                                   f'mode={self._config.mode} in network policy, PTQ algorithm use network policy '
                                   f'mode to quant.')
                    config_.mode = self._config.mode
                if config_.backend != self._config.backend:
                    logger.warning(f'The backend={config_.backend} in {key} layer policy different from '
                                   f'backend={self._config.backend} in network policy, PTQ algorithm use network '
                                   f'policy backend to quant.')
                    config_.backend = self._config.backend
                self.layer_policies[key] = InnerPTQConfig().inner_config(config_, approach=PTQApproach.PTQ)
                PTQ._ptq_config_check(self.layer_policies[key])
        except re.error as err:
            raise TypeError('The regular string of layer_policies not correct, please check and try again.') \
                from err

    @staticmethod
    def _hidden_states_of(output):
        """Return ``output[0]`` when a layer returns a tuple, otherwise `output` itself."""
        return output[0] if isinstance(output, tuple) else output

    # pylint: disable=arguments-differ
    # pylint: disable=unused-argument
    def apply(self, network: Cell, network_helper: NetworkHelper = None, datasets=None, **kwargs) -> Cell:
        """
        Define how to add fake quantizer to `network`.

        Args:
            network (Cell): Network to be fake quantized.
            network_helper (NetworkHelper): Utils for decoupling algorithm with network framework.
            datasets (Dataset): Datasets for calibrating.

        Returns:
            fake quantized network.

        Raises:
            RuntimeError: If PTQ is not well inited.
            TypeError: If input `network` is not a Cell.
            ValueError: If mode is QUANTIZE and the context mode is not PYNATIVE_MODE.
            ValueError: If input datasets is None.
        """
        self._config.update_comm_info()
        start_time = time.time()
        self._get_decoder_layers(network)
        logger.info(f"get_decoder_layers time cost {time.time() - start_time}")
        if self._config.mode == PTQMode.DEPLOY:
            # NOTE(review): nothing on this path actually unsets ENFORCE_EAGER or MS_JIT; confirm intent.
            logger.info("unset environ ENFORCE_EAGER and MS_JIT because of PTQMode.DEPLOY mode")
            for layer_name, layer in tqdm.tqdm(self.decoder_layers, desc="Running PTQ Deploy..."):
                for processor in self.pipeline:
                    # NOTE(review): only the replace step is wrapped in no_init_parameters; confirm deploy
                    # does not also need to run without parameter initialization.
                    with no_init_parameters():
                        processor.replace(layer_name, layer)
                    processor.deploy(layer_name, layer)
                    network.update_parameters_name()
            return network

        # Validate before touching the process environment so a failed call leaves no side effects.
        if get_context("mode") != PYNATIVE_MODE:
            raise ValueError("In QUANTIZE phase, please set mode=PYNATIVE_MODE.")
        if not datasets:
            raise ValueError("please provide dataset when use PTQ quant to quantize network.")
        os.environ['ENFORCE_EAGER'] = 'true'
        # NOTE(review): only ENFORCE_EAGER is set here; MS_JIT is not touched despite the message.
        logger.info("set environ ENFORCE_EAGER=true and MS_JIT=0 because of PTQMode.QUANTIZE mode")
        logger.info(f"Visible decoder layer types: {self.decoder_layer_types}. If decoder layer type of target network "
                    "not in list, please modify PTQ.decoder_layer_types before invoking apply method.")
        logger.info("Analysis network structure.")
        start_time = time.time()
        logger.info(f"Catching inputs for first decoder layer with {datasets.get_dataset_size()} datasets samples.")
        catcher, network = self._get_first_layer_input(network, datasets, network_helper)
        # all_args[k][0] holds the hidden-state input of the current layer for sample k; it is overwritten
        # with each layer's output so that it becomes the next layer's input.
        all_args = catcher.args
        all_kwargs = catcher.kwargs
        logger.info(f"_get_first_layer_input time cost {time.time() - start_time}")
        multi_layer = len(self.decoder_layers) > 1
        for i, (layer_name, layer) in enumerate(tqdm.tqdm(self.decoder_layers, desc="Running PTQ...")):
            logger.info(f"Quantize {i}th decoder layer.")
            cur_args, cur_kwargs = copy.deepcopy(all_args), copy.deepcopy(all_kwargs)
            if self._config.always_use_fp_input_in_processer:
                # Propagate outputs of the still-unmodified layer to the next layer.
                for index, (args, kwargs) in enumerate(zip(cur_args, cur_kwargs)):
                    output = layer(*args, **kwargs)
                    if multi_layer:
                        all_args[index][0] = self._hidden_states_of(output)
            for processor in self.pipeline:
                processor.replace(layer_name, layer, search_inputs=SearchInputs(layer, cur_args, cur_kwargs))
                logger.info("Catching inputs of all Linear in decoder layer.")
                start_time = time.time()
                transform_network_inplace(layer, WrapperCell, lambda _, cell: cell.add_hook())
                for index, (args, kwargs) in enumerate(zip(cur_args, cur_kwargs)):
                    output = layer(*args, **kwargs)
                    if multi_layer and not self._config.always_use_fp_input_in_processer:
                        # FIXME: 'always_use_fp_input_in_processer' is a temporary switch for fixing activation
                        # between layers. This branch may introduce error to the next layer, because previous
                        # processors in the pipeline change the layer and thus give an inaccurate output. Set the
                        # switch to True to avoid this issue. The switch should be removed after the issue is
                        # fixed. -- @tongl2
                        all_args[index][0] = self._hidden_states_of(output)
                transform_network_inplace(layer, WrapperCell, lambda _, cell: cell.remove_hook())
                logger.info(f"{i}th layer output refresh time cost {time.time() - start_time}")
                processor.process(layer_name, layer)
                processor.deploy(layer_name, layer)
                network.update_parameters_name()
                gc.collect()
                if self._config.reflash_inputs_after_each_processor:
                    # Unwrap tuple outputs like the other refresh sites, so the next layer gets hidden states.
                    for index, (args, kwargs) in enumerate(zip(cur_args, cur_kwargs)):
                        all_args[index][0] = self._hidden_states_of(layer(*args, **kwargs))
            start_time = time.time()
            offload_network(layer)
            gc.collect()
            logger.info(f"{i}th layer offload network time cost {time.time() - start_time}")
        return network

    def _get_first_layer_input(self, network: Cell, ds=None, helper=None):
        """Run `network` on every sample of `ds` and capture the inputs of the first decoder layer.

        Args:
            network (Cell): Network to run.
            ds (Dataset): Calibration dataset; each item must provide an ``'input_ids'`` column.
            helper (NetworkHelper, optional): Used for generation when it is a NetworkHelper.

        Returns:
            Tuple of the InputCatcher holding the captured args/kwargs, and `network`.

        Raises:
            ValueError: If `ds` is empty or None.
        """
        # Validate before patching so the first decoder layer is never left patched on error.
        if not ds:
            raise ValueError("PTQ need dataset to calibrate, please provide dataset.")
        catcher = InputCatcher()
        catcher.patch(self.decoder_layers[0][1])
        total_count = ds.get_dataset_size()
        try:
            for data_count, ds_item in enumerate(ds.create_dict_iterator(), start=1):
                logger.info(f"Calibrating: dataset count: {data_count}/{total_count}")
                input_ids = ds_item['input_ids'].asnumpy()
                try:
                    self._generate_func(network, input_ids, helper)
                except GeneratorExit:
                    # The catcher aborted generation; drop any cache the aborted run left behind.
                    block_mgr = getattr(network, "block_mgr", None)
                    if block_mgr:
                        block_mgr.clear_cache()
        finally:
            catcher.recover()
        if not catcher.args:
            logger.warning("No input of the first decoder layer was caught; check PTQ.decoder_layer_types "
                           "and the calibration dataset.")
        offload_network(network)
        return catcher, network

    def convert(self, net_opt: Cell, ckpt_path="") -> Cell:
        """
        Define how to convert a compressed network to a standard network before exporting.

        Args:
            net_opt (Cell): Network to be converted which is transformed by `RoundToNearest.apply`.
            ckpt_path (str): Path to checkpoint file for `net_opt`. Default is ``""``, which means not loading
                checkpoint file to `net_opt`.

        Returns:
            An instance of Cell represents quantized network.

        Raises:
            TypeError: If `net_opt` is not Cell.
            TypeError: If `ckpt_path` is not string.
            ValueError: If `ckpt_path` is not empty and invalid.
        """
        if not isinstance(net_opt, Cell):
            raise TypeError(
                f'The parameter `net_opt` must be isinstance of Cell, but got {type(net_opt)}.')
        if not isinstance(ckpt_path, str):
            raise TypeError(
                f'The parameter `ckpt_path` must be isinstance of str, but got {type(ckpt_path)}.')
        if ckpt_path != "":
            real_path = os.path.realpath(ckpt_path)
            if not os.path.isfile(real_path):
                raise ValueError(
                    f'The parameter `ckpt_path` can only be empty or a valid file, but got {real_path}.')
            param_dict = load_checkpoint(ckpt_path)
            load_param_into_net(net_opt, param_dict)
        self.summary(net_opt)
        return net_opt

    def _summary_target_layer_type(self) -> tuple:
        """Layer types collected from the pipeline algorithms, used by summary."""
        return self._target_layer_type

    def _summary_layer(self, layer_name, layer: Cell) -> Optional[str]:
        """Look up the collected quant info for `layer_name`.

        If nothing is found under the exact name, names ending with ``'._layer'`` or ``'.layer'`` fall back
        to the name with that suffix removed.
        """
        collected = self._config.layer_quant_info_collect
        info = collected.get(layer_name)
        for suffix in ('._layer', '.layer'):
            if not info and layer_name.endswith(suffix):
                info = collected.get(layer_name[:-len(suffix)])
        return info

    def _summary_title(self):
        return "Network Quantization Summary"

    def _summary_desc_name(self):
        return "quant_type"