代码
HyperOffload:设备内存优化技术深度解析

HyperOffload:设备内存优化技术深度解析

HyperOffload:设备内存优化技术深度解析

01 核心问题与解决思路

1.1 显存受限的本质问题

在模型训练与推理过程中,模型参数与中间 Activation 的显存占用往往超出设备显存容量限制。传统方案要求用户手动管理数据 placement 或使用模型并行分割策略,开发门槛较高且灵活性不足。

1.2 数据 Offload 优化策略

HyperOffload 采用"数据换时间"的优化思路:将不常用的数据暂时移到主机端,需要时再移回来。通过这种方式,虽然增加了数据传输开销,但显著降低了峰值显存占用,使原本无法在有限显存环境下执行的大模型得以运行。

1.3 图优化层三大任务

图优化层承担三项核心职责:首先是分析计算图,确定哪些数据适合 Offload;其次是决定 Offload 时机,包括何时将数据从设备移到主机、何时从主机移回设备;最后是修改计算图,插入必要的 Offload 操作节点,并调整执行顺序以实现计算与传输的重叠执行。

02 图优化层代码逻辑详解

2.1 优化器主流程

HyperOffloadOptimizer::Run() 是图优化层的核心入口,其执行逻辑遵循清晰的流水线式设计。

执行 SkipHyperOffloadOptimizer() 完成配置检查后,operations_.Init() 初始化操作管理队列,为每个执行节点预留操作存储空间。随后根据使能状态分别处理 Parameter Offload 和 Activation Offload 两类场景。

GenerateParameterOperations() 负责为远程 Parameter 在恰当位置插入加载与释放操作,GenerateActivationOperations() 则通过分析节点依赖关系,识别适合 Offload 的中间结果并插入对应的传输节点。

在完成节点插入后,AdjustHyperOffloadNodePosition() 调整 H2D 节点的物理位置,根据 prefetch_distance 配置让数据加载操作适当提前,使后续计算能够与数据传输并行执行。接下来 AddEventNodes() 注入 Send/Recv 同步事件,确保数据传输与计算执行之间的正确依赖关系。

最后,BuildExecutionOrder() 将所有 Offload 操作节点按照依赖关系整合进原始执行顺序,形成包含新增节点的完整执行序列;AssignHyperOffloadIds() 为所有 Offload 相关节点分配唯一标识符,用于运行时追踪数据流转状态。

2.2 GenerateParameterOperations 伪代码

# 为远程 Parameter 插入 H2D/D2H 操作节点
function GenerateParameterOperations():
    for each node in exec_order:
        # 遍历节点的所有输入
        for i from 0 to node.inputs().size() - 1:
            input_node = node.input(i)
            
            # 处理 Load 节点包裹的 Parameter
            if IsPrimitiveCNode(input_node, "Load"):
                real_input = GetKernelWithReturnType(input_node, 0, false, "Load").first
                if IsPrimitiveCNode(real_input, "Load"):
                    data_node = real_input.input(1)  # Load 节点的第二个输入
                    if IsRemoteParameter(data_node):  # 检查是否是远程 Parameter
                        # 插入 H2D 节点
                        to_device = BuildToDeviceNode(data_node)
                        UpdateNodeInput(node, i, to_device)
                        operations_.AddOperationBefore(node_index, to_device)
                        
                        # 插入 D2H 节点(原地更新)
                        to_remote = BuildInplaceToHostNode(to_device, data_node, node)
                        operations_.AddOperationAfter(node_index, to_remote)
                continue
            
            # 直接处理 Parameter 引用
            if IsRemoteParameter(input_node):
                # 插入 H2D 节点
                to_device = BuildToDeviceNode(input_node)
                UpdateNodeInput(node, i, to_device)
                operations_.AddOperationBefore(node_index, to_device)
               
                # 插入 D2H 节点(原地更新)
                param_node = GetRemoteParameter(graph, input_node)
                to_remote = BuildInplaceToHostNode(to_device, param_node, node)
                operations_.AddOperationAfter(node_index, to_remote)

2.3 GenerateActivationOperations 伪代码

# 识别需 Offload 的 Activation 并插入 D2H/H2D 节点
function GenerateActivationOperations():
    # 1. 建立数据 → 使用者映射关系
    user_info_list = CollectAllNodeUsers(exec_order)
    
    # 2. 基于距离策略筛选 Offload 目标
    strategy = DistanceBaseHyperOffloadStrategy()
    offload_info_list = strategy.Run(exec_order, user_info_list)
    
    # 3. 数量过滤
    filter = OffloadInfoFilterByNumber()
    offload_info_list = filter.Filter(offload_info_list)
    
    # 4. 获取图输出节点
    outputs = GetAllOutputWithIndex(graph.output())
    
    # 5. 为每个目标插入 Offload 节点
    for each offload_info in offload_info_list:
        AddSingleActivationOperations(offload_info, outputs)

# 单个 Activation 的 Offload 节点插入
function AddSingleActivationOperations(offload_info, outputs):
    data_node = offload_info.data_node
    
    # 1. 处理 Tuple 类型的输出
    if data_node.abstract is Tuple:
        index_node = NewValueNode(offload_info.data_node.index)
        data_node = graph.NewCNode([TupleGetItem, data_node, index_node])
    
    # 2. 在数据产生位置之后插入 D2H 节点
    to_host = BuildToHostNode(data_node)
    data_idx = FindIndexInExecOrder(data_node)
    operations_.AddOperationAfter(data_idx, to_host)
    
    # 3. 为每个使用位置插入 H2D 节点
    for each replace_info in offload_info.replace_info_list:
        to_device = BuildToDeviceNode(to_host)
        
        # 替换原节点输入
        for each change_node in replace_info.replace_rest_nodes:
            UpdateNodeInput(change_node, change_node.index, to_device)
        
        # 在使用位置之前插入 H2D 节点
        change_idx = FindIndexInExecOrder(change_node)
        operations_.AddOperationBefore(change_idx, to_device)
    
    # 4. 处理输出节点
    if data_node in outputs:
        to_device = BuildToDeviceNode(to_host)
        UpdateNodeInput(output, output.index, to_device)
        operations_.AddOperationAfter(last_index, to_device)

2.4 AdjustHyperOffloadNodePosition 伪代码

# 根据 prefetch_distance 调整 H2D 节点位置,实现数据预取
function AdjustHyperOffloadNodePosition():
    # 1. 获取所有已插入的 H2D 节点
    h2d_nodes = operations_.GetAllH2DNodes()
 
    for each h2d_node in h2d_nodes:
        # 2. 获取 H2D 节点关联的原始数据节点
        data_index = 1  # 输入参数位置
        input_node = h2d_node.input(data_index)
 
        if IsD2HNode(input_node):
            # H2D 节点的输入是 D2H 节点,取 D2H 的输入作为原始数据
            data_node = input_node.input(data_index)
        else:
           data_node = input_node
 
        # 3. 获取 prefetch_distance(优先使用节点属性,其次使用全局配置)
        prefetch_distance = GetNodePrefetchDistance(data_node) or
GLOBAL_CONFIG.prefetch_distance
 
        # 4. 将 H2D 节点向前移动指定距离,实现预取
        operations_.MoveAhead(h2d_node, prefetch_distance)

2.5 BuildExecutionOrder 伪代码

# 将 Offload 操作节点插入到执行顺序中,构建新执行序列
function BuildExecutionOrder():
    new_execution_order = []
    
    # 1. 添加前置操作(独立于任何执行节点的操作)
    for op in operations_.GetPreOperations():
        new_execution_order.append(op)
    
    # 2. 遍历原始执行顺序,交错插入 Offload 操作
    for i from 0 to len(exec_order) - 1:      
        # 2.1 添加原始执行节点
        new_execution_order.append(exec_order[i])
        
        # 2.2 添加该位置对应的 Offload 操作(在 exec_order[i] 之后执行)
        for op in operations_.GetOperations()[i]:
            new_execution_order.append(op)
    
    return new_execution_order
# 示例:Activation Offload 的节点插入
# 原始 exec_order: [A, B, C, D, E]
# 假设 A 的输出 X 被 E 使用,且 A 与 E 之间距离超过阈值,需要 Offload
#
# 节点关系:
#   A(data: X) → E 使用 X
#   D2H 插入在 A 之后,H2D 插入在 E 之前
#
# operations_ 队列:
#   pre_operations: []  // 无前置操作
#   operations[0]: [D2H1]     // A 之后:D2H1(input: A)
#   operations[1]: []         // B 之后
#   operations[2]: []         // C 之后
#   operations[3]: []         // D 之后
#   operations[4]: [H2D1]     // E 之前:H2D1(input: D2H1)
#
# 构建结果: [A, D2H1(A), B, C, D, E, H2D1(D2H1)]
#
# 数据流向:
#   A(X) → D2H1(输入 A 的输出 X,输出 X 到主机) → E(从主机加载 X)
#                          ↘
#                           H2D1(输入 D2H1 的输出 X)

2.6 关键函数简述

Activation Offload 的处理流程包含四个关键环节。首先,CollectAllNodeUsers() 遍历计算图建立数据与使用者的映射关系,明确每个数据节点被哪些操作消费以及消费的先后顺序。然后 DistanceBaseHyperOffloadStrategy::Run() 基于使用距离判定 Offload 必要性:当数据产生位置与后续使用位置之间的间隔超过 select_distance 阈值时,判定该数据适合 Offload。判定结果经 OffloadInfoFilterByNumber::Filter() 数量过滤后,最终 AddSingleActivationOperations() 执行实际的节点插入操作:在数据产生处插入 D2H 节点将数据送回主机,在各使用处插入 H2D 节点重新加载数据。

事件注入环节,AddEventNodes() 为每个 D2H 和 H2D 节点配对 Send/Recv 事件,通过在不同执行流上调度发送与接收操作,实现计算与传输的时间重叠。

03 关键数据结构汇总

HyperOffloadInput — 优化器输入结构,包含待优化的计算图、原始执行顺序、Parameter Offload 开关、Activation Offload 开关、以及新节点回调函数。

HyperOffloadPlan — 优化器输出结构,包含优化后的计算图、包含 Offload 节点的新执行顺序、以及 HyperOffloadOperations 操作管理队列。

UserInfo — 封装数据节点与其全部消费者的映射关系,node 字段标识数据生产者,users 字段存储按执行顺序排列的消费者列表。

OffloadInfo — 描述单个数据的 Offload 方案,data_node 标识目标数据,replace_info_list 包含所有需要重写数据引用的节点信息。

HyperOffloadOperations — Offload 操作管理类,维护前置操作队列和每个执行位置对应的操作列表,支持 AddOperationBefore、AddOperationAfter、MoveAhead 等操作。

04 Python 使用示例

Python 层通过 @jit 装饰器的 auto_offload 参数暴露使用接口。Activation Offload 使用 @jit(auto_offload="activation") 启用,适用于网络中存在跨长距离复用的中间激活值场景。

from mindspore import jit

@jit(auto_offload="activation")
def forward(x):
    m1 = x / 2
    m3 = m1 * 2
    m4 = m3 * 2
    m5 = m3 + m4
    return (m5 * 2) - m1

运行时参数可通过 mindspore.graph.compile_config 模块调整,包括 SELECT_DISTANCE(距离阈值)、SELECT_NUM(Offload 上限)、PREFETCH_DISTANCE(预取距离)、RELEASE_DISTANCE(释放距离)等配置项。

05 总结与适用场景

HyperOffload 聚焦于 Ascend 设备显存受限场景下的模型高效执行,其基于数据流分析的编译期优化与基于异步事件的运行时协同机制,为大模型训练与内存受限推理提供了切实可行的解决方案。该技术适用于参数量庞大的 Transformer 类模型、长序列处理任务以及显存容量有限的边缘推理场景,核心价值体现在透明集成性、灵活配置性以及异步并行带来的性能收益。后续可探索基于机器学习的自适应策略选择、跨设备分布式 Offload 协同等优化方向。

关于HyperOffload的更多链接,请参考

https://atomgit.com/mindspore/mindspore/tree/master/mindspore/ccsrc/utils/hyper_offload

https://mp.weixin.qq.com/s?__biz=MzkxMTM2MjMzNg==&mid=2247634387&idx=1&sn=4f92b52ce05f25ad4604a56aef9e7484&scene=21&poc_token=HDIOhGmjOQivkVwCNZdvzEy_hINwjwgoJuTr1uka