lite_boost.parallel.ParallelManager
- class lite_boost.parallel.ParallelManager(target)[源代码]
对支持的模型进行原地修改,使其支持分布式并行推理。
ParallelManager封装一个支持的模型或流水线对象,并对其进行原地补丁替换,以实现多NPU设备的并行推理。根据检测到的模型组件,自动应用以下两种并行策略:Ulysses序列并行(USP) 用于DiT模型 - 补丁替换
forward方法和注意力层,通过all_to_all通信实现序列维度并行,每张卡持有完整模型权重,仅对序列的一个切片进行计算。数据并行(DP)时间切片 用于VAE模型 - 将
vae.encode和vae.decode替换为DP时间切片版本,沿时间维度将视频切分为重叠的帧片段,分发到各卡独立处理,最后收集拼接为完整结果。
当传入流水线对象(如
WanT2V或WanTI2V)时,两种策略同时生效,DiT模型应用USP,VAE应用DP。模型在原地修改后原样返回,因此所有已有的属性和方法(
.to、.cpu、.eval等)均可正常使用。内部的补丁替换由
lite_boost.model.setup_model()分发执行,该函数自动检测模型类型并应用对应的适配器(例如,将flash_attention替换为NPU兼容版本,在注意力层前后插入all_to_all通信对,以及将DP时间切片绑定到VAE的encode/decode方法)。- 参数:
target (object) – 需要并行化的支持流水线对象,模型类型通过
lite_boost.model.setup_model()自动检测,支持的类包括WanT2V和WanTI2V。
- 返回:
object,与输入相同的实例,已原地修改为USP补丁后的forward和注意力方法(DiT)以及DP补丁后的encode/decode方法(VAE)。
- 异常:
RuntimeError - 模型类型不被lite_boost支持时抛出。
样例:
>>> import os >>> os.environ["RANK"] = "0" >>> os.environ["WORLD_SIZE"] = "1" >>> from lite_boost.parallel import initialize_usp, ParallelManager >>> from wan.textimage2video import WanTI2V >>> initialize_usp() >>> pipe = WanTI2V(config=cfg, checkpoint_dir=ckpt_dir, ...) >>> ParallelManager(pipe)