运算重载

在线运行下载Notebook下载样例代码查看源文件

概述

mindspore.ops.composite中提供了一些涉及图变换的组合类算子,例如MultitypeFuncGraphHyperMap等。

MultitypeFuncGraph

MultitypeFuncGraph用于生成重载函数,支持不同类型的输入。用户可以使用MultitypeFuncGraph定义一组重载的函数,根据不同类型,采用不同实现。首先初始化一个MultitypeFuncGraph 对象,使用带有输入类型的 register 作为待注册函数的装饰器,使得该对象支持多种类型的输入。更多使用方法见:MultitypeFuncGraph

代码样例如下:

[1]:
import numpy as np
from mindspore.ops import MultitypeFuncGraph
from mindspore import Tensor
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.tensor_add(x, y)

tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
print('tensor', add(tensor1, tensor2))
print('scalar', add(1, 2))
tensor [[2.4 4.2]
 [4.4 6.4]]
scalar 3

HyperMap

HyperMap可以对一组或多组输入做指定的运算,可以配合MultitypeFuncGraph一起使用。例如定义一组重载的add函数后,对多组不同类型的输入进行add运算。不同于MapHyperMap 能够用于嵌套结构,对序列或嵌套序列中的输入做指定运算。更多使用方法见:HyperMap

代码样例如下:

[2]:
from mindspore import dtype as mstype
from mindspore import Tensor
from mindspore.ops import MultitypeFuncGraph, HyperMap
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.tensor_add(x, y)

add_map = HyperMap(add)
output = add_map((Tensor(1, mstype.float32), Tensor(2, mstype.float32), 1), (Tensor(3, mstype.float32), Tensor(4, mstype.float32), 2))
print("output =", output)
output = (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 6), 3)

此例子中传入add_map的输入包含了两个序列,HyperMap会以operation(args[0][i], args[1][i])的形式分别从两个序列中取相应的元素作为add函数的输入xy,例如add(Tensor(1, mstype.float32), Tensor(3, mstype.float32))