lite_boost.parallel.ParallelManager
- class lite_boost.parallel.ParallelManager
Modify a supported model in-place for distributed parallel inference.
ParallelManager wraps a supported model or pipeline and patches it in-place for multi-NPU parallel inference. Two parallelism strategies are applied automatically, depending on the detected model components:
  - Ulysses Sequence Parallel (USP) for DiT models: patches the forward method and attention layers to enable sequence-dimension parallelism via all_to_all communication. Each device holds the full model weights and operates on a slice of the sequence.
  - Data Parallel (DP) temporal tiling for VAE models: replaces vae.encode and vae.decode with DP temporal-slicing versions that split the video along the temporal dimension into overlapping chunks, distribute the chunks across devices, and gather the results.
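The Ulysses all_to_all resharding can be illustrated with a single-process NumPy simulation. This is a sketch of the general technique, not lite_boost's code. It assumes an activation layout of (sequence, heads, head_dim), a sequence length and head count that both divide evenly by the number of devices, and a hypothetical helper simulate_all_to_all that stands in for a real collective such as torch.distributed.all_to_all.

```python
import numpy as np

def simulate_all_to_all(shards, split_axis, concat_axis):
    """Simulate an all_to_all across len(shards) devices in one process.

    Hypothetical helper: device i splits its tensor into P chunks along
    split_axis and sends chunk j to device j; device j concatenates the chunks
    it receives, ordered by source rank, along concat_axis.
    """
    p = len(shards)
    chunks = [np.split(s, p, axis=split_axis) for s in shards]  # chunks[src][dst]
    return [
        np.concatenate([chunks[src][dst] for src in range(p)], axis=concat_axis)
        for dst in range(p)
    ]

# Hypothetical sizes: 8 tokens, 4 attention heads, head_dim 2, 2 devices.
seq_len, num_heads, head_dim, world_size = 8, 4, 2, 2
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, num_heads, head_dim))

# Before attention: each device holds a contiguous slice of the sequence, all heads.
seq_shards = np.split(q, world_size, axis=0)  # each (4, 4, 2)

# First all_to_all: scatter heads, gather sequence.
# Each device now sees the full sequence for a subset of the heads.
head_shards = simulate_all_to_all(seq_shards, split_axis=1, concat_axis=0)
assert head_shards[0].shape == (seq_len, num_heads // world_size, head_dim)
assert np.array_equal(head_shards[1], q[:, 2:4, :])  # device 1 owns heads 2 and 3

# (Attention would run here, per head, over the full sequence.)

# Second all_to_all: scatter sequence, gather heads -> back to sequence shards.
restored = simulate_all_to_all(head_shards, split_axis=0, concat_axis=1)
assert all(np.array_equal(r, s) for r, s in zip(restored, seq_shards))
```

In a real USP attention layer the first exchange is applied to Q, K and V and the second to the attention output, which is why the exchanges come in pairs around each attention layer while every device keeps the full weights.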
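The DP temporal tiling idea can be sketched as: give each device a contiguous block of frames, extend that block by a few overlap frames on each side, process the padded blocks independently, then trim the overlap and concatenate along time. The overlap size, the trim-and-concatenate merge, the frame-first (frames, channels, height, width) layout, and the helper names temporal_chunks and gather_chunks are all assumptions for illustration. lite_boost's actual slicing and merge policy may differ (for example, it could blend overlapping frames), and a VAE that compresses the temporal axis would need its output boundaries rescaled.

```python
import numpy as np

def temporal_chunks(num_frames, world_size, overlap):
    """Return (start, end, keep_start, keep_end) for each device (hypothetical helper).

    [start, end) is the padded frame slice a device processes;
    [keep_start, keep_end) is the chunk-local range it contributes to the output.
    """
    if not 0 < world_size <= num_frames:
        raise ValueError("need 1 <= world_size <= num_frames")
    plan = []
    for idx in np.array_split(np.arange(num_frames), world_size):
        core_start, core_end = int(idx[0]), int(idx[-1]) + 1
        start = max(0, core_start - overlap)
        end = min(num_frames, core_end + overlap)
        plan.append((start, end, core_start - start, core_end - start))
    return plan

def gather_chunks(outputs, plan):
    """Drop the overlap frames from each processed chunk and concatenate along time."""
    return np.concatenate(
        [out[keep_s:keep_e] for out, (_, _, keep_s, keep_e) in zip(outputs, plan)],
        axis=0,
    )

# Hypothetical video: 10 frames, 3 channels, 4x4 pixels, split across 3 devices.
video = np.random.default_rng(0).standard_normal((10, 3, 4, 4))
plan = temporal_chunks(num_frames=10, world_size=3, overlap=1)

# Each "device" runs a frame-wise stand-in op on its padded slice. A frame-wise op
# makes the result easy to check; with temporal convolutions, the overlap frames
# would supply the missing context at chunk borders.
outputs = [np.tanh(video[s:e]) for s, e, _, _ in plan]
result = gather_chunks(outputs, plan)
assert np.allclose(result, np.tanh(video))
```

Here the chunks are processed in one process; in a DP setting each padded slice would be encoded or decoded on its own device and the trimmed results gathered before concatenation.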
When a pipeline object (e.g., WanT2V) is passed, both strategies are applied: USP for the DiT model and DP for the VAE. When a raw WanModel is passed, only USP is applied.

The model is modified in-place and returned as-is, so all existing attributes and methods (.to, .cpu, .eval, etc.) continue to work normally.

The internal patching is dispatched by lite_boost.model.setup_model(), which detects the model type and applies the corresponding adapter (e.g., replacing flash_attention with an NPU-compatible version, inserting all_to_all communication pairs around the attention layers, and binding DP temporal tiling to the VAE encode/decode methods).
- Parameters:
target (object) – A supported pipeline or model object to be parallelized. The model type is auto-detected via lite_boost.model.setup_model(). Supported classes include WanT2V and WanTI2V; a raw WanModel is also accepted, in which case only USP is applied.
- Returns:
object – The same target instance, modified in-place, with USP-patched forward and attention methods (for DiT) and DP-patched encode/decode methods (for VAE).
- Raises:
RuntimeError – If the model type is not supported by lite_boost.
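The in-place contract described here (detect the type, patch the object, hand back the same instance, raise RuntimeError for unsupported types) can be illustrated with a small self-contained sketch. Everything in it is hypothetical: ToyPipeline, the _ADAPTERS registry, _toy_adapter and ToyParallelManager are stand-ins rather than lite_boost internals, and a real adapter would install NPU attention and distributed communication instead of a string-prefixing wrapper.

```python
import functools

class ToyPipeline:
    """Hypothetical stand-in for a supported pipeline such as WanTI2V."""

    def generate(self, prompt):
        return f"video for {prompt!r}"

def _toy_adapter(pipe):
    """Hypothetical adapter: wrap a bound method and store it on the instance."""
    original = pipe.generate

    @functools.wraps(original)
    def patched(*args, **kwargs):
        return "[parallel] " + original(*args, **kwargs)

    pipe.generate = patched  # in-place patch; other attributes are untouched

# Hypothetical registry mapping supported classes to their adapters.
_ADAPTERS = {ToyPipeline: _toy_adapter}

class ToyParallelManager:
    def __new__(cls, target):
        adapter = _ADAPTERS.get(type(target))
        if adapter is None:
            raise RuntimeError(f"{type(target).__name__} is not supported")
        adapter(target)
        # Returning a non-instance of cls means __init__ is skipped and the
        # caller gets its own (now patched) object back.
        return target

pipe = ToyPipeline()
same = ToyParallelManager(pipe)
assert same is pipe
print(pipe.generate("a cat"))  # [parallel] video for 'a cat'

try:
    ToyParallelManager(object())
except RuntimeError as err:
    print(err)  # object is not supported
```

Returning target from __new__ is one way to make a constructor call hand back the caller's own object; the source does not say which mechanism lite_boost uses.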
Examples
>>> import os
>>> os.environ["RANK"] = "0"
>>> os.environ["WORLD_SIZE"] = "1"  # single-process run
>>> from lite_boost.parallel import initialize_usp, ParallelManager
>>> from wan.textimage2video import WanTI2V
>>> initialize_usp()
>>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...)
>>> ParallelManager(pipe)