Model Quantization
Overview
MindSpore is an all-scenario AI framework. When a model is deployed on device-side or other lightweight environments, it may be subject to constraints on memory, power consumption, and latency. Therefore, the model needs to be compressed before deployment.
MindSpore Golden Stick provides the model compression capability of MindSpore. It is a set of model compression algorithms jointly designed and developed by Huawei's Noah's Ark team and the Huawei MindSpore team, supporting quantization modes such as A16W8, A16W4, A8W8, and KVCache quantization. For details, see MindSpore Golden Stick.
Basic Model Quantization Process
To help you understand the basic model quantization process of MindSpore Golden Stick, this section uses the post-training quantization (PTQ) algorithm as an example to describe the basic usage.
Procedure
The MindSpore Golden Stick quantization workflow can be divided into two phases: the quantization phase and the deployment phase. The quantization phase is completed before deployment; its main tasks are collecting the weight distribution, computing quantization parameters, quantizing the weight data, and inserting dequantization nodes. The deployment phase refers to the process of using the MindSpore framework to perform inference with the quantized model in the production environment.
MindSpore Golden Stick mainly uses PTQConfig to customize quantization and deployment, and uses the apply and convert APIs to carry them out. In PTQConfig, you can configure whether to quantize the weight, activation, and KVCache, the quantization bit width, and the data calibration policy. For details, see PTQConfig Description.
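At a glance, both phases follow the same pattern: construct a PTQConfig (with the mode set to quantization or deployment), create a PTQ object from it, and call apply followed by convert on the network. In the quantization phase, apply additionally takes a calibration dataset and the quantized weights are saved to a checkpoint; in the deployment phase, only the network structure is converted and the quantized checkpoint is loaded before inference.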
The quantization procedure of MindSpore Golden Stick is as follows:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, dtype
from mindformers.modules import Linear
from mindspore_gs.common import BackendTarget
from mindspore_gs.ptq import PTQMode, PTQConfig
from mindspore_gs.ptq.ptq import PTQ
from mindspore.dataset import GeneratorDataset


class SimpleNet(nn.Cell):
    class DecoderCell(nn.Cell):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def construct(self, *args, **kwargs):
            return self.linear(*args, **kwargs)

    def __init__(self, foo_seq_length=1024):
        super().__init__()
        self.foo_seq_length = foo_seq_length
        linear = Linear(in_channels=foo_seq_length, out_channels=foo_seq_length, weight_init="ones")
        self.decoder = SimpleNet.DecoderCell(linear)

    def construct(self, x):
        return self.decoder(x)

    def generate(self, input_ids, do_sample=False, max_new_tokens=1):
        input_ids = np.pad(input_ids, ((0, 0), (0, self.foo_seq_length - input_ids.shape[1])), 'constant',
                           constant_values=0)
        return self.construct(Tensor(input_ids, dtype=dtype.float16))


def create_foo_ds(repeat=1):
    # A minimal calibration dataset used to collect statistics during quantization.
    class SimpleIterable:
        def __init__(self, repeat=1):
            self._index = 0
            self.data = []
            for _ in range(repeat):
                self.data.append(np.array([[1, 1, 1]], dtype=np.int32))

        def __next__(self):
            if self._index >= len(self.data):
                raise StopIteration
            item = (self.data[self._index],)
            self._index += 1
            return item

        def __iter__(self):
            self._index = 0
            return self

        def __len__(self):
            return len(self.data)

    return GeneratorDataset(source=SimpleIterable(repeat), column_names=["input_ids"])


net = SimpleNet()  # The floating-point network to be quantized
ds = create_foo_ds(1)
cfg = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND, weight_quant_dtype=dtype.int8)
ptq = PTQ(cfg)
ptq.apply(net, datasets=ds)   # collect statistics and build the fake-quantized network
ptq.convert(net)              # perform real quantization on the fake-quantized network
ms.save_checkpoint(net.parameters_dict(), './simplenet_ptq.ckpt')  # save the quantized weights
Use nn.Cell to define the network. After the model is trained, the floating-point weights of the model are obtained and are loaded for inference. The preceding example simplifies the process by directly creating a network and quantizing it with its initial floating-point weights (a sketch of loading a trained floating-point checkpoint follows these steps).
Use PTQConfig to set the mode to quantization and backend to Ascend for 8-bit quantization of the weights. For details, see PTQConfig Description.
Use the apply API to convert the network into a fake-quantized network and collect statistics on the quantization objects according to PTQConfig.
Use the convert API to perform real quantization on the fake-quantized network obtained in the previous step to obtain the quantized network.
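For reference, in a real workflow the trained floating-point checkpoint is loaded into the network before apply is called. A minimal sketch, assuming a trained checkpoint exists (the path below is illustrative and is not produced by the example above):

import mindspore as ms

net = SimpleNet()
# Illustrative path to a trained floating-point checkpoint; replace with your own.
ms.load_checkpoint('./simplenet_fp16.ckpt', net)
# Then continue with PTQConfig, apply, and convert as shown above.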
After the quantization is complete, you can use the quantized model for inference. The procedure is as follows:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, dtype
from mindformers.modules import Linear
from mindspore_gs.common import BackendTarget
from mindspore_gs.ptq import PTQMode, PTQConfig
from mindspore_gs.ptq.ptq import PTQ


class SimpleNet(nn.Cell):
    class DecoderCell(nn.Cell):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def construct(self, *args, **kwargs):
            return self.linear(*args, **kwargs)

    def __init__(self, foo_seq_length=1024):
        super().__init__()
        self.foo_seq_length = foo_seq_length
        linear = Linear(in_channels=foo_seq_length, out_channels=foo_seq_length, weight_init="ones")
        self.decoder = SimpleNet.DecoderCell(linear)

    def construct(self, x):
        return self.decoder(x)

    def generate(self, input_ids, do_sample=False, max_new_tokens=1):
        input_ids = np.pad(input_ids, ((0, 0), (0, self.foo_seq_length - input_ids.shape[1])), 'constant',
                           constant_values=0)
        return self.construct(Tensor(input_ids, dtype=dtype.float16))


net = SimpleNet()
cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=dtype.int8)
ptq = PTQ(cfg)
ptq.apply(net)     # convert the network structure only; no statistics are collected
ptq.convert(net)
ms.load_checkpoint('./simplenet_ptq.ckpt', net)  # load the quantized weights
input_data = Tensor(np.ones((5, 1024), dtype=np.float32), dtype=dtype.float32)
output = net(input_data)
print(output)
Use PTQConfig to set the mode to deployment and backend to Ascend for 8-bit quantization of the weights. For details, see PTQConfig Description.
Use the apply and convert APIs to convert the network into a quantized network. In the deployment phase, no statistics are collected and no quantization computation is performed; only the network structure is converted into a quantized network.
Load the quantized weights into the quantized network for inference.
PTQConfig Description
You can customize PTQConfig to enable different quantization capabilities. For details about PTQConfig, see the API document. The following lists configuration examples for these algorithms:
A indicates activation, W indicates weight, C indicates KVCache, and the number indicates the bit width. For example, A16W8 indicates that the activation is kept in float16 and the weight is quantized to int8.
A16W8 weight quantization
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType

ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8,
                       act_quant_dtype=None,
                       kvcache_quant_dtype=None,
                       outliers_suppression=OutliersSuppressionType.NONE)
A8W8 quantization
A8W8 quantization is based on the SmoothQuant algorithm. PTQConfig provides the outliers_suppression field to specify whether to perform the smooth operation.
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType

ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8,
                       act_quant_dtype=msdtype.int8,
                       kvcache_quant_dtype=None,
                       outliers_suppression=OutliersSuppressionType.SMOOTH)
KVCache int8 quantization
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType

ptq_config = PTQConfig(weight_quant_dtype=None,
                       act_quant_dtype=None,
                       kvcache_quant_dtype=msdtype.int8,
                       outliers_suppression=OutliersSuppressionType.NONE)
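Each of these configurations is used in the same way as in the procedure above: pass it to the PTQ constructor (ptq = PTQ(ptq_config)) and then call apply and convert on the network.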
Examples
PTQ Examples
The following provides the complete process of quantizing and deploying the Llama2 network with the post-training quantization (PTQ) algorithm:
PTQ algorithm: supports 8-bit weight quantization, 8-bit full quantization, and KVCacheInt8 quantization. SmoothQuant can be used to improve the quantization precision, and different quantization algorithms can be combined to improve the quantization inference performance (see the sketch below).
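As an illustration of combining these capabilities in a single configuration, the fields shown in PTQConfig Description can be set together. A sketch, assuming the installed version supports combining A8W8 (with SmoothQuant) and KVCache int8:

from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType

# Sketch: A8W8 quantization combined with KVCache int8.
# Assumes this combination is supported by the installed MindSpore Golden Stick version.
ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8,
                       act_quant_dtype=msdtype.int8,
                       kvcache_quant_dtype=msdtype.int8,
                       outliers_suppression=OutliersSuppressionType.SMOOTH)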
Quantization Aware Training Examples
SimQAT algorithm: A basic quantization aware training algorithm based on fake quantization technology.
SLB quantization algorithm: A nonlinear, low-bit quantization aware training algorithm.
Pruning Examples
SCOP pruning algorithm: A structured weight pruning algorithm.