# 异构并行训练

## 计算流程

MindSpore异构并行训练典型的计算流程如下图所示：

1. 用户设置网络执行的后端

[ ]:

# Select the default execution backend (device target) for the whole network.
import mindspore as ms
ms.set_context(device_target="GPU")

2. 用户设置特定算子执行后端

[1]:

from mindspore import ops

# Instantiate the operator first — the original snippet called set_device on an
# undefined name `prim`.  Any Primitive instance works; `set_device("CPU")`
# marks this single operator to execute on the CPU backend while the rest of
# the graph stays on the default device.
prim = ops.Add()
prim.set_device("CPU")

3. 框架根据计算图算子标志进行切图

4. 框架调度不同后端执行子图

## 优化器异构

1. 配置优化器算子到CPU执行

2. 初始化FP16的权重参数以及FP32的优化器状态变量

3. 将输入优化器的梯度转为FP16（如果本来就是FP16梯度，可忽略这步）

4. 权重和梯度转为FP32参与优化器运算

5. 更新后的FP32权重赋值给FP16的权重

[2]:

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore.common.initializer import initializer
from mindspore.nn import Optimizer
# Assign/Cast instances pinned to the CPU backend: the optimizer-state updates
# and the FP16<->FP32 precision conversions run in host memory.
host_assign = ops.Assign()
host_assign.set_device("CPU")
host_cast = ops.Cast()
host_cast.set_device("CPU")
# This Cast has no set_device call, so it runs on the default device backend.
device_cast = ops.Cast()

@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """Apply one AdamWeightDecay step on a single parameter.

    The FP16 weight `param` is cast to FP32 on the host, updated by the
    CPU-resident `opt` kernel, then cast back and assigned into `param`.
    Returns True (wrapped in a depend edge so the assignment is not pruned).

    NOTE(review): indentation of this cell was lost in the original document;
    it is restored here following the control flow implied by the statements.
    """
    success = True
    if optim_filter:
        # Cast the FP16 weight to FP32 on host so the optimizer computes in full precision.
        param32 = host_cast(param, ms.float32)
        if decay_flags:
            next_param = opt(param32, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            # decay disabled for this parameter: weight_decay forced to 0.0
            next_param = opt(param32, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        # depend() orders the cast-back after the update; the result is written
        # back into the original (FP16) parameter via the host Assign.
        ret = host_assign(param, host_cast(ops.depend(param32, next_param), ops.dtype(param)))
        return ops.depend(success, ret)
    return success

def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
    """Build a heterogeneous AdamWeightDecay optimizer.

    Args:
        params: network parameters (expected FP16 weights).
        learning_rate, beta1, beta2, eps, weight_decay: standard Adam hyper-parameters.
    """
    # NOTE(review): the original snippet omitted the Optimizer base __init__,
    # never assigned self.opt (used on the next line), and never created the
    # FP32 moment states used by construct(); all three are restored here.
    super().__init__(learning_rate, params, weight_decay)
    self.beta1 = ms.Tensor(np.array([beta1]).astype(np.float32))
    self.beta2 = ms.Tensor(np.array([beta2]).astype(np.float32))
    self.eps = ms.Tensor(np.array([eps]).astype(np.float32))
    # FP32 optimizer state variables cloned from the (FP16) weights.
    self.moments1 = self.clone_param32(prefix="adam_m", init='zeros')
    self.moments2 = self.clone_param32(prefix="adam_v", init='zeros')
    self.hyper_map = ops.HyperMap()
    # The AdamWeightDecay kernel is pinned to CPU — this is the "optimizer
    # heterogeneous" step of the tutorial.
    self.opt = ops.AdamWeightDecay()
    self.opt.set_device("CPU")

def construct(self, gradients):
    """One optimizer step: map `_adam_opt` over every parameter.

    NOTE(review): the original cell lost its `def construct` header and the
    tails of all three map_reverse calls; the trailing arguments are restored
    to match `_update_run_kernel`'s registered signature
    (..., param, m, v, gradient, decay_flags, optim_filter).
    """
    lr = self.get_lr()
    if self.is_group:
        if self.is_group_lr:
            # per-group learning rate: lr is broadcast alongside the parameters
            optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
                                            lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                            gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
                                            self.weight_decay, self.parameters, self.moments1, self.moments2,
                                            gradients, self.decay_flags, self.optim_filter)
    else:
        optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay), self.parameters, self.moments1, self.moments2,
                                        gradients, self.decay_flags, self.optim_filter)
    return optim_result

def clone_param32(self, prefix, init=None):
    """Clone every network parameter as an FP32 state variable.

    Args:
        prefix: name prefix for the cloned parameters (e.g. "adam_m").
        init: initializer for the clones; if None, reuse each parameter's
            own initializer.

    Returns:
        ms.ParameterTuple of FP32 clones, one per network parameter.

    NOTE(review): indentation of this cell was lost in the original document
    and is restored here; the statements themselves are unchanged.
    """
    new = []
    for old_param in self.parameters:
        param_init = init
        if init is None:
            param_init = old_param.init
        new_state = old_param.clone()
        new_state.set_dtype(ms.float32)
        new_state.set_data(initializer(param_init, shape=old_param.shape, dtype=ms.float32))
        new_state.name = prefix + '.' + new_state.name
        new.append(new_state)
    return ms.ParameterTuple(new)


## Embedding异构

1. 配置EmbeddingLookup算子到CPU执行

[3]:

# Pin the EmbeddingLookup primitive to the CPU backend so the (potentially very
# large) embedding table stays in host memory.
ops.EmbeddingLookup().set_device('CPU')

2. 配置EmbeddingLookup相关优化器到CPU执行

[4]:

# Flags for the CPU-side sparse optimizer that updates the embedding table
# (presumably passed to an operator such as FusedSparseLazyAdam — the call
# site is not shown in this snippet; verify against the full example).
use_locking = False
use_nesterov = False


EmbeddingLookup算子设置代码样例如下：

[5]:

import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms
from mindspore.common.initializer import initializer

class EmbeddingLookup(nn.Cell):
    """Embedding-lookup layer whose table lookup can run on CPU (host memory).

    Args:
        vocab_size: number of rows in the embedding table (positive int).
        embedding_size: embedding vector width (positive int).
        param_init: initializer for the embedding table.
        target: 'CPU' or 'DEVICE' — which primitive construct() dispatches to.
        sparse: use SparseGatherV2 instead of Gather; required when target is 'CPU'.

    NOTE(review): indentation of this cell was lost in the original document
    and is restored here.  `validator` (mindspore's parameter-check helper) is
    used but not imported in this snippet — confirm the import in the full example.
    """

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', sparse=True):
        """Initialize EmbeddingLookup."""
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
        self.target = target
        self.sparse = sparse
        if target not in ('CPU', 'DEVICE'):
            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
        if not sparse and target == 'CPU':
            raise ValueError('When target is CPU, embedding_lookup must be sparse.')
        if sparse:
            self.gatherv2 = ops.SparseGatherV2()
        else:
            self.gatherv2 = ops.Gather()
        # The lookup primitive is always pinned to CPU; `target` only decides
        # which primitive construct() actually calls.
        self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')
        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
        self.embedding_table = ms.Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                            name='embedding_table')

    def construct(self, indices):
        # Dispatch: CPU-pinned EmbeddingLookup vs. device Gather/SparseGatherV2.
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            out = self.gatherv2(self.embedding_table, indices, 0)
        return out


## PS异构

Parameter Server封装异构流程，用户只需配置参数使用PS即可，具体配置流程请参考Parameter Server训练流程