mindspore_gs.ptq.BaseQuantForCausalLM

class mindspore_gs.ptq.BaseQuantForCausalLM

Base Class for Causal Language Model Quantization

This is the base class that defines the standard interface shared by all quantized causal language model implementations. Derived classes must implement calibrate, fake_quant, forward, from_pretrained and save_quantized, each of which raises NotImplementedError in this base class.

The class implements a registry mechanism that allows different model frameworks to register their implementations. This enables the automatic model detection and selection functionality provided by AutoQuantForCausalLM.
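The registry-based selection can be pictured with a few lines of plain Python. This is a minimal sketch under stated assumptions, not the mindspore_gs implementation: `QuantBase`, `ToyHub`, `_MODEL_HUBS`, `auto_from_pretrained` and the `model_path` keyword are all hypothetical, and selecting the hub by an explicit name is an assumption (AutoQuantForCausalLM may detect the framework differently).

```python
from typing import Dict, Type


class QuantBase:
    """Hypothetical stand-in for BaseQuantForCausalLM."""

    @classmethod
    def from_pretrained(cls, **kwargs):
        raise NotImplementedError


class ToyHub(QuantBase):
    """Hypothetical hub implementation that records its creation kwargs."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @classmethod
    def from_pretrained(cls, **kwargs):
        return cls(**kwargs)


# Hypothetical registry: hub name -> implementation class.
_MODEL_HUBS: Dict[str, Type[QuantBase]] = {"toy": ToyHub}


def auto_from_pretrained(hub: str, **kwargs) -> QuantBase:
    """Look up the class registered under `hub` and delegate to its from_pretrained."""
    try:
        impl = _MODEL_HUBS[hub]
    except KeyError:
        raise ValueError(f"unknown model hub {hub!r}; registered: {sorted(_MODEL_HUBS)}") from None
    return impl.from_pretrained(**kwargs)


model = auto_from_pretrained("toy", model_path="/path/to/model")
print(type(model).__name__)  # ToyHub
```

The caller never names the concrete class; adding a new framework only means adding a registry entry.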

Examples

>>> from mindspore_gs.ptq.models.base_model import BaseQuantForCausalLM
>>>
>>> # A custom model implementation
>>> class MyCustomQuantModel(BaseQuantForCausalLM):
...     pass

calibrate(ptq_config, layers_policy, datasets, **kwargs)

Calibrate and quantize the model.

This is an abstract method that must be implemented by derived classes. It should handle the model calibration process using calibration datasets and apply quantization according to the provided configuration.

Parameters
  • ptq_config (PTQConfig) – Configuration for post-training quantization.

  • layers_policy (dict) – Policy for different layer quantization strategies.

  • datasets (Dataset) – Calibration dataset for quantization.

  • **kwargs – Additional keyword arguments.

Raises

NotImplementedError – This method must be implemented by subclasses.
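To show what a derived class might put behind this contract, here is a minimal sketch using symmetric abs-max calibration, a common generic technique that is not necessarily what mindspore_gs uses. All names are hypothetical: `ptq_config` is assumed to be a dict with a `bits` key, `layers_policy` maps layer names to `"quant"` or `"skip"`, and each calibration sample is a dict of pre-recorded activations per layer, standing in for running the model on real inputs.

```python
from typing import Dict, List


class QuantBase:
    """Hypothetical stand-in: the base class only declares the contract."""

    def calibrate(self, ptq_config, layers_policy, datasets, **kwargs):
        raise NotImplementedError


class ToyQuantModel(QuantBase):
    """Hypothetical subclass using symmetric abs-max calibration."""

    def __init__(self, weights: Dict[str, List[float]]):
        self.weights = weights
        self.qweights: Dict[str, List[int]] = {}
        self.w_scales: Dict[str, float] = {}
        self.act_scales: Dict[str, float] = {}

    def calibrate(self, ptq_config, layers_policy, datasets, **kwargs):
        qmax = 2 ** (ptq_config["bits"] - 1) - 1  # 127 for 8 bits
        # 1) Walk the calibration data and record each layer's activation abs-max.
        act_absmax: Dict[str, float] = {}
        for sample in datasets:
            for layer, acts in sample.items():
                peak = max(abs(a) for a in acts)
                act_absmax[layer] = max(act_absmax.get(layer, 0.0), peak)
        # 2) Quantize only the layers the policy selects; the others stay in float.
        for layer, w in self.weights.items():
            if layers_policy.get(layer) != "quant":
                continue
            scale = max(abs(x) for x in w) / qmax or 1.0  # guard all-zero weights
            self.w_scales[layer] = scale
            self.qweights[layer] = [round(x / scale) for x in w]
            self.act_scales[layer] = act_absmax.get(layer, 0.0) / qmax


model = ToyQuantModel({"fc1": [0.5, -1.27, 0.25], "head": [0.1, 0.2]})
calib_data = [{"fc1": [1.0, -2.54]}, {"fc1": [0.3]}]
model.calibrate({"bits": 8}, {"fc1": "quant", "head": "skip"}, calib_data)
print(model.qweights)  # {'fc1': [50, -127, 25]}
```

Calling `calibrate` on the bare base class still raises NotImplementedError, which is how an incomplete subclass surfaces.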

fake_quant(ptq_config, layers_policy, quant_safetensors_path='')

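Apply fake quantization to the model.

This method simulates quantization in floating point: values are rounded to the quantized grid and mapped back to floats, but computation is not converted to integer operations. This makes it useful for validating the accuracy impact of quantization before producing an actually quantized model. It must be implemented by derived classes.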

Parameters
  • ptq_config (PTQConfig) – Configuration for post-training quantization.

  • layers_policy (dict) – Policy for different layer quantization strategies.

  • quant_safetensors_path (str, optional) – Path to quantized safetensors. Defaults to "".

Raises

NotImplementedError – This method must be implemented by subclasses.
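Fake quantization is usually implemented as quantize-then-dequantize. A minimal sketch of that general technique, assuming symmetric per-tensor scaling; `fake_quantize` is a hypothetical helper, not part of mindspore_gs, and the library's own scheme (granularity, rounding, clipping) may differ.

```python
from typing import List


def fake_quantize(values: List[float], bits: int = 8) -> List[float]:
    """Symmetric per-tensor quantize-dequantize; results remain floats."""
    qmax = 2 ** (bits - 1) - 1      # 127 for 8 bits
    qmin = -qmax - 1                # -128 for 8 bits
    scale = max(abs(v) for v in values) / qmax or 1.0  # guard an all-zero tensor
    out = []
    for v in values:
        q = min(max(round(v / scale), qmin), qmax)  # snap to the integer grid
        out.append(q * scale)                       # dequantize back to float
    return out


weights = [0.5, -1.27, 0.004]
fq = fake_quantize(weights)
# The rounding error is what the fake-quantized model "sees".
errors = [abs(a - b) for a, b in zip(weights, fq)]
print(max(errors) <= (1.27 / 127) / 2 + 1e-12)  # True: error is at most half a step
```

Small values such as 0.004 collapse to zero, which is exactly the kind of effect fake quantization lets you measure before committing to integer kernels.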

forward(input_ids, max_new_tokens=1)

Perform forward pass through the model.

This is an abstract method that must be implemented by derived classes. It should handle the forward pass logic for model inference.

Parameters
  • input_ids (Tensor) – Input token IDs for the model.

  • max_new_tokens (int, optional) – Maximum number of tokens to generate. Defaults to 1.

Returns

Forward pass results.

Raises

NotImplementedError – This method must be implemented by subclasses.
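The documentation only states that `max_new_tokens` caps how many tokens are generated. One common way to realize that is an autoregressive greedy loop, sketched below under assumptions: plain Python lists stand in for the Tensor `input_ids`, and `greedy_forward`, `next_token_logits`, `eos_id` and `toy_logits` are hypothetical names, not mindspore_gs API.

```python
from typing import Callable, List


def greedy_forward(input_ids: List[int],
                   next_token_logits: Callable[[List[int]], List[float]],
                   max_new_tokens: int = 1,
                   eos_id: int = -1) -> List[int]:
    """Append up to max_new_tokens greedily chosen tokens to input_ids."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(token)
        if token == eos_id:  # stop early at end-of-sequence
            break
    return ids


def toy_logits(ids: List[int]) -> List[float]:
    """Hypothetical toy model: always prefers (last_token + 1) mod 5."""
    nxt = (ids[-1] + 1) % 5
    return [1.0 if i == nxt else 0.0 for i in range(5)]


print(greedy_forward([0, 1], toy_logits))                 # [0, 1, 2]
print(greedy_forward([3], toy_logits, max_new_tokens=3))  # [3, 4, 0, 1]
```

With the default `max_new_tokens=1`, a single forward call extends the sequence by exactly one token.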

classmethod from_pretrained(**kwargs)

Create a model instance from pretrained weights.

This is an abstract method that must be implemented by derived classes. It should handle loading pretrained model weights and configuration.

Parameters

**kwargs (dict) – Arbitrary keyword arguments for model creation.

Returns

BaseQuantForCausalLM. An instance of the quantized model.

Raises

NotImplementedError – This method must be implemented by subclasses.
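A derived class typically implements this classmethod as a factory that reads stored configuration and returns `cls(...)`, so the result has the type of whichever subclass it was called on. A minimal sketch; the `model_path` keyword, the `config.json` file name and `ToyQuantModel` are assumptions for illustration, not documented mindspore_gs behavior.

```python
import json
import os
import tempfile


class ToyQuantModel:
    """Hypothetical subclass showing the from_pretrained factory pattern."""

    def __init__(self, config: dict):
        self.config = config

    @classmethod
    def from_pretrained(cls, **kwargs):
        # "model_path" is an assumed keyword, not a documented one.
        model_path = kwargs.pop("model_path")
        with open(os.path.join(model_path, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        config.update(kwargs)  # remaining kwargs override the stored settings
        return cls(config)


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"num_layers": 2, "dtype": "float16"}, f)
    model = ToyQuantModel.from_pretrained(model_path=tmp, dtype="bfloat16")
print(model.config)  # {'num_layers': 2, 'dtype': 'bfloat16'}
```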

static get_model_hub_registry()

Get the registry of all registered model hubs.

Returns

dict[str, type]. Dictionary mapping model hub names to their respective class implementations.

static reg_model_hub(alias=None)

Decorator for registering model hub implementations.

This decorator registers a class as a model hub implementation, making it available for automatic detection and selection.

Parameters

alias (str, optional) – Alternative name for the model hub. If not provided, the class name will be used. Defaults to None.

Returns

function. Decorator function that registers the class.
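The decorator-with-alias pattern, together with a registry getter, can be sketched as follows. This is an illustration under assumptions, not the library source: `QuantBase` and `_model_hubs` are hypothetical, and returning a shallow copy from `get_model_hub_registry` is a design choice of this sketch that the documentation does not state.

```python
from typing import Callable, Dict, Optional


class QuantBase:
    """Hypothetical stand-in for BaseQuantForCausalLM's registry helpers."""

    _model_hubs: Dict[str, type] = {}

    @staticmethod
    def reg_model_hub(alias: Optional[str] = None) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            name = alias if alias is not None else cls.__name__  # fall back to the class name
            QuantBase._model_hubs[name] = cls
            return cls  # return the class unchanged so it stays usable
        return decorator

    @staticmethod
    def get_model_hub_registry() -> Dict[str, type]:
        return dict(QuantBase._model_hubs)  # shallow copy: callers cannot mutate the registry


@QuantBase.reg_model_hub("qwen3")
class QWen3QuantModel(QuantBase):
    pass


@QuantBase.reg_model_hub()
class MyCustomQuantModel(QuantBase):
    pass


print(sorted(QuantBase.get_model_hub_registry()))  # ['MyCustomQuantModel', 'qwen3']
```

Because `alias` defaults to None, the decorator must still be called with parentheses (`@reg_model_hub()`) when no alias is given.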

Examples

>>> @BaseQuantForCausalLM.reg_model_hub("qwen3")
... class QWen3QuantModel(BaseQuantForCausalLM):
...     pass

save_quantized(save_path)

Save the quantized model to disk.

This is an abstract method that must be implemented by derived classes. It should handle saving the quantized model weights and configuration.

Parameters

save_path (str) – Path where the quantized model should be saved.

Raises

NotImplementedError – This method must be implemented by subclasses.
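A derived class might persist the quantized weights, their scales and the configuration side by side so a loader can rebuild the model. A minimal sketch under assumptions: `ToyQuantModel` and the file names are hypothetical, and JSON is used only to keep the example dependency-free; a real implementation would more likely write a binary weight format such as safetensors.

```python
import json
import os
import tempfile
from typing import Dict, List


class ToyQuantModel:
    """Hypothetical model holding int8 weights, per-layer scales and a config."""

    def __init__(self, config: dict, qweights: Dict[str, List[int]], scales: Dict[str, float]):
        self.config = config
        self.qweights = qweights
        self.scales = scales

    def save_quantized(self, save_path: str) -> None:
        os.makedirs(save_path, exist_ok=True)  # create the target directory if needed
        # Store the quantization settings next to the weights.
        with open(os.path.join(save_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(self.config, f, indent=2)
        # JSON keeps the sketch self-contained; real weights would use a binary format.
        with open(os.path.join(save_path, "quant_weights.json"), "w", encoding="utf-8") as f:
            json.dump({"qweights": self.qweights, "scales": self.scales}, f)


model = ToyQuantModel({"bits": 8}, {"fc1": [50, -127, 25]}, {"fc1": 0.01})
with tempfile.TemporaryDirectory() as tmp:
    out_dir = os.path.join(tmp, "quantized")
    model.save_quantized(out_dir)
    print(sorted(os.listdir(out_dir)))  # ['config.json', 'quant_weights.json']
```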