lite_boost.parallel.ParallelManager

查看源文件
class lite_boost.parallel.ParallelManager(target)[源代码]

对支持的模型进行原地修改,使其支持分布式并行推理。

ParallelManager 封装一个支持的模型或流水线对象,并对其进行原地补丁替换,以实现多NPU设备的并行推理。根据检测到的模型组件,自动应用以下两种并行策略:

  • Ulysses序列并行(USP) 用于DiT模型 - 补丁替换 forward 方法和注意力层,通过 all_to_all 通信实现序列维度并行,每张卡持有完整模型权重,仅对序列的一个切片进行计算。

  • 数据并行(DP)时间切片 用于VAE模型 - 将 vae.encodevae.decode 替换为DP时间切片版本,沿时间维度将视频切分为重叠的帧片段,分发到各卡独立处理,最后收集拼接为完整结果。

当传入流水线对象(如 WanT2VWanTI2V)时,两种策略同时生效,DiT模型应用USP,VAE应用DP。

模型在原地修改后原样返回,因此所有已有的属性和方法( .to.cpu.eval 等)均可正常使用。

内部的补丁替换由 lite_boost.model.setup_model() 分发执行,该函数自动检测模型类型并应用对应的适配器(例如,将 flash_attention 替换为NPU兼容版本,在注意力层前后插入 all_to_all 通信对,以及将DP时间切片绑定到VAE的encode/decode方法)。

参数:
  • target (object) – 需要并行化的支持流水线对象,模型类型通过 lite_boost.model.setup_model() 自动检测,支持的类包括 WanT2VWanTI2V

返回:

object,与输入相同的实例,已原地修改为USP补丁后的forward和注意力方法(DiT)以及DP补丁后的encode/decode方法(VAE)。

异常:
  • RuntimeError - 模型类型不被lite_boost支持时抛出。

样例:

>>> import os
>>> os.environ["RANK"] = "0"
>>> os.environ["WORLD_SIZE"] = "1"
>>> from lite_boost.parallel import initialize_usp, ParallelManager
>>> from wan.textimage2video import WanTI2V
>>> initialize_usp()
>>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...)
>>> ParallelManager(pipe)