# Source code for lite_boost.parallel._manager

#!/usr/bin/env python3
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
r"""
ParallelManager — one-line model parallelization for distributed inference.

This module provides the :func:`initialize_usp` function and the
:class:`ParallelManager` class, which together enable distributed inference
on supported models with minimal code changes. Two parallelism strategies
are applied automatically based on the model type:

- **Ulysses Sequence Parallel (USP)** for DiT models — sequence-dimension
  parallelism via ``all_to_all`` communication around attention layers.
- **Data Parallel (DP) temporal tiling** for VAE models — temporal-dimension
  slicing with overlap, distributed across devices.

Usage:
    >>> import os
    >>> os.environ["RANK"] = "0"
    >>> os.environ["WORLD_SIZE"] = "1"
    >>> from lite_boost.parallel import initialize_usp, ParallelManager
    >>> from wan.textimage2video import WanTI2V
    >>> initialize_usp()
    >>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...)
    >>> ParallelManager(pipe)
"""
import os

import torch
import torch.distributed as dist
import torch_npu


def initialize_usp():
    r"""
    Initialize the HCCL distributed environment for parallel inference.

    This function configures the NPU runtime settings and initializes the
    HCCL distributed process group by reading the following environment
    variables:

    - ``RANK``: Global rank of the current process. Default: ``0``.
    - ``LOCAL_RANK``: Index of the NPU device to bind in this process (set
      per node by launchers such as ``torchrun``). Default: the value of
      ``RANK``, which matches the device index on single-node launches.
    - ``WORLD_SIZE``: Total number of distributed processes. Default: ``1``.
    - ``MASTER_ADDR``: IP address of the master node. Default: ``"127.0.0.1"``.
    - ``MASTER_PORT``: Port of the master node. Default: ``29502``.
    - ``NUM_THREADS``: Number of CPU threads per process. Default: ``24``.

    If the distributed process group has not been initialized, this function
    initializes it with the ``hccl`` backend and a ``tcp://MASTER_ADDR:MASTER_PORT``
    init method. An already-initialized process group is reused as-is. In both
    cases, the NPU device selected by ``LOCAL_RANK`` (or ``RANK``) is then set
    as the active device.

    Note:
        This function must be called before constructing
        :class:`ParallelManager`. It is typically invoked at the entry point
        of a distributed training or inference script.

    Raises:
        RuntimeError: If HCCL process group initialization fails.
        ValueError: If ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
            ``MASTER_PORT`` or ``NUM_THREADS`` is not a valid integer.

    Examples:
        >>> import os
        >>> os.environ["RANK"] = "0"
        >>> os.environ["WORLD_SIZE"] = "1"
        >>> from lite_boost.parallel import initialize_usp
        >>> initialize_usp()
    """
    # NPU runtime configuration: disallow internal tensor formats and turn
    # off JIT compile mode before any device or communication setup.
    torch.npu.config.allow_internal_format = False
    torch.npu.set_compile_mode(jit_compile=False)

    rank = int(os.getenv("RANK", "0"))
    # RANK is global across all nodes, but set_device needs a node-local
    # device index. Prefer LOCAL_RANK and fall back to RANK so single-node
    # launches that only export RANK behave exactly as before.
    local_rank = int(os.getenv("LOCAL_RANK") or rank)
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    port = int(os.getenv("MASTER_PORT", "29502"))
    torch.set_num_threads(int(os.getenv("NUM_THREADS", "24")))

    if not dist.is_initialized():
        dist.init_process_group(
            backend="hccl",
            init_method=f"tcp://{master_addr}:{port}",
            world_size=world_size,
            rank=rank,
        )
    torch_npu.npu.set_device(local_rank)
class ParallelManager:
    r"""
    Patch a supported model or pipeline in-place for multi-NPU inference.

    Although written as a class, :class:`ParallelManager` behaves like a
    factory function. Constructing it applies lite_boost's parallelization
    patches to ``target`` and hands ``target`` itself back to the caller; no
    :class:`ParallelManager` instance is ever created. Consequently,
    ``isinstance(ParallelManager(pipe), ParallelManager)`` is ``False``.

    Two parallelism strategies are selected from the detected components:

    - **Ulysses Sequence Parallel (USP)** for DiT models — the ``forward``
      method and attention layers are patched so that the sequence dimension
      is split across devices, with ``all_to_all`` communication around
      attention. Every device keeps the full model weights and processes a
      slice of the sequence.
    - **Data Parallel (DP) temporal tiling** for VAE models — ``vae.encode``
      and ``vae.decode`` are replaced by versions that cut the video into
      overlapping temporal chunks, spread them over the devices, and gather
      the results.

    A pipeline object (e.g. ``WanT2V``) receives both strategies: USP for its
    DiT model and DP for its VAE. A raw ``WanModel`` receives USP only.
    Because patching happens in-place, the object's existing attributes and
    methods (``.to``, ``.cpu``, ``.eval``, ...) keep working.

    The actual patching is delegated to :func:`lite_boost.model.setup_model`.
    That function detects the model type and applies the matching adapter,
    for example:

    - swapping ``flash_attention`` for an NPU-compatible implementation;
    - wrapping attention in ``all_to_all`` communication pairs;
    - binding DP temporal tiling to the VAE encode/decode.

    Args:
        target (object): The pipeline or model to parallelize. Its type is
            auto-detected by :func:`lite_boost.model.setup_model`. Supported
            types include the ``WanT2V`` and ``WanTI2V`` pipelines, and a raw
            ``WanModel`` (USP only).

    Returns:
        object, ``target`` itself, modified in-place. DiT components have
        USP-patched forward/attention methods, and VAE components have
        DP-patched encode/decode methods.

    Raises:
        RuntimeError: If the model type is not supported by lite_boost
            (raised from :func:`lite_boost.model.setup_model`).

    Examples:
        >>> import os
        >>> os.environ["RANK"] = "0"
        >>> os.environ["WORLD_SIZE"] = "1"
        >>> from lite_boost.parallel import initialize_usp, ParallelManager
        >>> from wan.textimage2video import WanTI2V
        >>> initialize_usp()
        >>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...)
        >>> ParallelManager(pipe)
    """

    def __new__(cls, target):
        # Resolved at call time rather than module import time; presumably to
        # avoid a circular import between lite_boost.parallel and
        # lite_boost.model — confirm before hoisting.
        from lite_boost.model import setup_model as patch_target_in_place

        patch_target_in_place(target)

        # __new__ returns the caller's object rather than an instance of cls.
        # For any target that is not a ParallelManager, Python therefore never
        # invokes __init__.
        return target