mindspore_gs.ptq.PTQ
- class mindspore_gs.ptq.PTQ(config=None, layer_policies=None)[源代码]
- 量化算法PTQ的基本实现,支持激活、权重和kvcache的组合量化。 - 参数:
- config ( - mindspore_gs.ptq.PTQConfig, 可选) - 用于配置PTQ量化算法,默认值为- None。
- layer_policies (OrderedDict, 可选) - 层的量化策略,默认值为 - None。layer_policies 的key为正则字符串用来匹配层的名称,value为- mindspore_gs.ptq.PTQConfig。
 
- 异常:
- TypeError - config 在输入不为 - None时,元素类型不为 PTQConfig。
- TypeError - layer_policies 中任一元素不为 - None时,其类型不为 PTQConfig。
- ValueError - config 中的 mode 是PTQMode.QUANTIZE时非PYNATIVE模式。 
- ValueError - 当act_quant_dtype是int8类型,weight_quant_dtype为None时。 
- TypeError - layer_policies 的类型不为 OrderedDict。 
 
 - 样例: - >>> import mindspore_gs >>> from mindspore_gs.ptq import PTQ >>> from mindspore_gs.ptq import PTQConfig >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper >>> from mindformers.tools.register.config import MindFormerConfig >>> from mindformers import LlamaForCausalLM, LlamaConfig >>> from mindspore_gs.common.gs_enum import BackendTarget >>> from mindspore import dtype as msdtype >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file" >>> mfconfig = MindFormerConfig(mf_yaml_config_file) >>> helper = MFLlama2Helper(mfconfig) >>> backend = BackendTarget.ASCEND >>> ptq_config = PTQConfig(mode=PTQMode.QUANTIZE, backend=backend, opname_blacklist=["w2", "lm_head"], ... weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8, ... outliers_suppression=OutliersSuppressionType.SMOOTH) >>> attn_policy = PTQConfig(mode=PTQMode.QUANTIZE, backend=backend, ... weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8, ... outliers_suppression=OutliersSuppressionType.NONE) >>> layer_policy = OrderedDict({r'.*Attention.wo.*': attn_policy}) >>> ptq = PTQ(ptq_config, layer_policy) >>> network = LlamaForCausalLM(LlamaConfig(**mfconfig.model.model_config)) >>> fake_quant_net = ptq.apply(network, helper) >>> quant_net = ptq.convert(fake_quant_net) >>> ptq.summary(quant_net) - apply(network, network_helper=None, datasets=None, **kwargs)[源代码]
- 将 network 中添加伪量化节点,转换成一个伪量化网络。 - 参数:
- network (Cell) - 待伪量化的网络。 
- network_helper (NetworkHelper) - 网络量化工具,用于解耦算法层和网络框架层。 
- datasets (Dataset) - 校准用的数据集。 
 
- 返回:
- 伪量化后的网络。 
- 异常:
- RuntimeError - 如果当前算法没有有效的初始化。 
- TypeError - network 不是一个 Cell 对象。 
- ValueError - 当datasets为空。 
 
 
 - convert(net_opt, ckpt_path='')[源代码]
- 将量化网络 net_opt 转换为真实量化网络,后续导出用于部署。 - 参数:
- net_opt (Cell) - 经过量化算法apply之后的网络。 
- ckpt_path (str) - 网络的checkpoint file文件路径,默认值为 - "",表示不加载。注意,该参数会在后续版本中被遗弃。
 
- 返回:
- 转换后的网络。 
 
 - fake_quant(network)[源代码]
- 将网络应用伪量化操作。该方法对模型应用伪量化,这对于验证量化效果而不实际转换为整数操作非常有用。 - 参数:
- network (Cell) - 待伪量化的网络。 
 
- 返回:
- 伪量化后的网络。 
- 异常:
- TypeError - network 数据类型不是Cell。 
 
 
 - summary(network)
- 总结并打印网络中每一层的使用的模型压缩策略,如某一个线性层的量化策略为:W8-perchannel-A8-pertensor。 - 参数:
- network (Cell) - 待总结的网络。 
 
- 异常:
- TypeError - network 数据类型不是Cell。