lite_boost.parallel.ParallelManager

class lite_boost.parallel.ParallelManager

Modify a supported model in-place for distributed parallel inference.

ParallelManager takes a supported model or pipeline and patches it in-place for parallel inference across multiple NPUs. It applies two parallelism strategies automatically, depending on which model components it detects:

  • Ulysses Sequence Parallel (USP) for DiT models: patches the forward method and attention layers to parallelize along the sequence dimension using all_to_all communication. Each device holds the full model weights and processes one slice of the sequence.

  • Data Parallel (DP) temporal tiling for VAE models: replaces vae.encode and vae.decode with DP versions that split the video along the temporal dimension into overlapping chunks, distribute the chunks across devices, and gather the results.
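A minimal sketch of overlapping temporal chunking, simulated in a single process with NumPy. It assumes a [C, T, H, W] layout, a contiguous split of the T frames across ranks with `overlap` halo frames on each side, and halo cropping (rather than blending) when stitching. The helpers `split_with_halo`, `dp_temporal_apply`, and `temporal_smooth` are hypothetical, and the chunk sizes, overlap, and merge rule that lite_boost actually uses are not described here. `temporal_smooth` is a toy stand-in for encode/decode whose temporal receptive field reaches one frame in each direction.

```python
import numpy as np

def split_with_halo(num_frames, world_size, overlap):
    """Plan one overlapping temporal chunk per rank.

    Returns (start, end, keep_start, keep_end) per rank: the rank reads frames
    [start, end) and keeps local indices [keep_start, keep_end), which map back
    to its contiguous "core" range of the full video.
    """
    base, rem = divmod(num_frames, world_size)
    plans, core_start = [], 0
    for rank in range(world_size):
        core_end = core_start + base + (1 if rank < rem else 0)
        start = max(0, core_start - overlap)       # halo before the core
        end = min(num_frames, core_end + overlap)  # halo after the core
        plans.append((start, end, core_start - start, core_end - start))
        core_start = core_end
    return plans

def dp_temporal_apply(video, fn, world_size, overlap):
    """Run fn on overlapping chunks of video ([C, T, H, W]) and stitch the cores.

    Each loop iteration stands in for one rank; the final concatenate stands
    in for gathering the per-rank results. fn must preserve the T dimension.
    """
    plans = split_with_halo(video.shape[1], world_size, overlap)
    outputs = []
    for start, end, keep_start, keep_end in plans:
        chunk_out = fn(video[:, start:end])
        outputs.append(chunk_out[:, keep_start:keep_end])  # drop the halo frames
    return np.concatenate(outputs, axis=1)

def temporal_smooth(x):
    """Toy stand-in for encode/decode: 3-frame moving average, edge-padded."""
    padded = np.concatenate([x[:, :1], x, x[:, -1:]], axis=1)
    return (padded[:, :-2] + padded[:, 1:-1] + padded[:, 2:]) / 3.0

if __name__ == "__main__":
    video = np.random.default_rng(0).standard_normal((3, 17, 4, 4))
    tiled = dp_temporal_apply(video, temporal_smooth, world_size=4, overlap=1)
    print(np.allclose(tiled, temporal_smooth(video)))  # True
```

In this toy, every kept frame depends only on neighbours inside its halo, so the stitched output matches a single full-length pass. With a real VAE the overlap would have to cover the operator's temporal receptive field for that equivalence to hold.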

When a pipeline object (e.g., WanT2V) is passed, both strategies are applied: USP for the DiT model and DP for the VAE. When a raw WanModel is passed, only USP is applied.
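The dispatch rule can be sketched with toy stand-ins. Everything in this block is hypothetical: `ToyDiT`, `ToyVAE`, `ToyPipeline`, `setup_parallel`, and the assumption that a pipeline exposes its DiT as `.model` and its VAE as `.vae` are illustrative only and do not describe how lite_boost.model.setup_model() actually inspects its input. Unsupported inputs raise RuntimeError, matching the documented behaviour.

```python
class ToyDiT:
    """Stand-in for a raw DiT model such as WanModel."""
    def __init__(self):
        self.usp_enabled = False

class ToyVAE:
    """Stand-in for a pipeline's VAE."""
    def __init__(self):
        self.dp_enabled = False

class ToyPipeline:
    """Stand-in for a pipeline such as WanT2V: a DiT plus a VAE."""
    def __init__(self):
        self.model = ToyDiT()
        self.vae = ToyVAE()

def apply_usp(dit):
    # Placeholder for patching forward/attention with all_to_all pairs.
    dit.usp_enabled = True

def apply_dp_vae(vae):
    # Placeholder for swapping encode/decode with temporal-tiling versions.
    vae.dp_enabled = True

def setup_parallel(target):
    """Hypothetical dispatcher mirroring the pipeline vs. raw-model rule."""
    if isinstance(target, ToyDiT):
        apply_usp(target)  # raw DiT: USP only
    elif isinstance(getattr(target, "model", None), ToyDiT) and hasattr(target, "vae"):
        apply_usp(target.model)   # pipeline: USP for the DiT
        apply_dp_vae(target.vae)  # and DP temporal tiling for the VAE
    else:
        raise RuntimeError(f"{type(target).__name__} is not supported")
    return target  # same object, modified in place
```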

The target is modified in-place and the same object is returned (no wrapper is created), so all existing attributes and methods (.to, .cpu, .eval, etc.) continue to work normally.
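In Python, calling a class can return the original object when `__new__` returns something that is not an instance of that class; `__init__` is then skipped. The sketch below assumes ParallelManager could behave this way (this page does not say how it is implemented); `InPlacePatcher` and `ToyModel` are hypothetical names used only for illustration.

```python
class ToyModel:
    """Stand-in model with a forward method and an eval-style method."""
    def __init__(self):
        self.training = True

    def forward(self, x):
        return x * 2

    def eval(self):
        self.training = False
        return self

class InPlacePatcher:
    """Hypothetical: patch `target` in place and return it instead of a wrapper."""
    def __new__(cls, target):
        original_forward = target.forward  # bound method from the original class

        def patched_forward(*args, **kwargs):
            # A real patch would add communication around this call;
            # here we only count calls to show the patch is active.
            target.forward_calls = getattr(target, "forward_calls", 0) + 1
            return original_forward(*args, **kwargs)

        target.forward = patched_forward  # instance attribute shadows the method
        return target  # not an InPlacePatcher, so __init__ is not called
```

Because no wrapper is created, `InPlacePatcher(model) is model` holds, and unpatched methods such as `eval` behave exactly as before.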

The internal patching is dispatched by lite_boost.model.setup_model(), which detects the model type and applies the corresponding adapter (e.g., replacing flash_attention with an NPU-compatible version, inserting all_to_all communication pairs around the attention layers, and binding DP temporal tiling to the VAE encode/decode).
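The data movement behind Ulysses sequence parallelism can be simulated in one process with NumPy. This is a sketch under stated assumptions: each simulated rank starts with a sequence shard of Q, K and V laid out as [S/P, H, D]; a first all-to-all regroups the data so each rank holds the full sequence for H/P heads; plain softmax attention (standing in for the NPU flash_attention kernel) runs per rank; and a second all-to-all restores the sequence shards. The helper names are hypothetical, a real run uses collective communication across devices, and S and H must be divisible by P here.

```python
import numpy as np

def attention(q, k, v):
    """Softmax attention per head. q: [S, H, D], k and v: [T, H, D] -> [S, H, D]."""
    scores = np.einsum("shd,thd->hst", q, k) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hst,thd->shd", weights, v)

def all_to_all_seq_to_head(shards):
    """Per-rank [S/P, H, D] -> per-rank [S, H/P, D] (simulated all_to_all)."""
    p = len(shards)
    pieces = [np.split(s, p, axis=1) for s in shards]  # pieces[src][head_group]
    return [np.concatenate([pieces[src][dst] for src in range(p)], axis=0)
            for dst in range(p)]

def all_to_all_head_to_seq(shards):
    """Per-rank [S, H/P, D] -> per-rank [S/P, H, D] (the inverse exchange)."""
    p = len(shards)
    pieces = [np.split(s, p, axis=0) for s in shards]  # pieces[src][seq_chunk]
    return [np.concatenate([pieces[src][dst] for src in range(p)], axis=1)
            for dst in range(p)]

def ulysses_attention(q_shards, k_shards, v_shards):
    """all_to_all to head shards, local attention, all_to_all back to sequence shards."""
    q_h = all_to_all_seq_to_head(q_shards)
    k_h = all_to_all_seq_to_head(k_shards)
    v_h = all_to_all_seq_to_head(v_shards)
    out_h = [attention(q, k, v) for q, k, v in zip(q_h, k_h, v_h)]
    return all_to_all_head_to_seq(out_h)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    S, H, D, P = 8, 4, 16, 2
    q, k, v = (rng.standard_normal((S, H, D)) for _ in range(3))
    q_s, k_s, v_s = (np.split(x, P, axis=0) for x in (q, k, v))
    out_shards = ulysses_attention(q_s, k_s, v_s)
    print(np.allclose(np.concatenate(out_shards, axis=0), attention(q, k, v)))  # True
```

Since attention heads are independent, each rank computes exact attention for its head group over the full sequence, which is why full weights plus two all-to-all exchanges per attention layer suffice.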

Parameters:

target (object) – A supported model or pipeline object to be parallelized. The model type is auto-detected via lite_boost.model.setup_model(). Supported classes include the WanT2V and WanTI2V pipelines and a raw WanModel (USP only).

Returns:

object, the same instance modified in-place with USP-patched forward and attention methods (for DiT) and DP-patched encode/decode methods (for VAE).

Raises:

RuntimeError – If the model type is not supported by lite_boost.

Examples

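>>> import os
>>> # Single-process setup for illustration; in a multi-NPU run each
>>> # process needs its own RANK and the shared WORLD_SIZE.
>>> os.environ["RANK"] = "0"
>>> os.environ["WORLD_SIZE"] = "1"
>>> from lite_boost.parallel import initialize_usp, ParallelManager
>>> from wan.textimage2video import WanTI2V
>>> initialize_usp()
>>> # cfg and ckpt_dir are assumed to come from the usual Wan config and checkpoint setup.
>>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...)
>>> pipe = ParallelManager(pipe)  # patched in-place; returns the same object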