{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Operation Overloading\n", "\n", "[![Run Online](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMS43L3R1dG9yaWFscy9leHBlcnRzL3poX2NuL29wZXJhdGlvbi9taW5kc3BvcmVfb3Bfb3ZlcmxvYWQuaXB5bmI=&imageid=9d63f4d1-dc09-4873-b669-3483cea777c0) [![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.7/tutorials/experts/zh_cn/operation/mindspore_op_overload.ipynb) [![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.7/tutorials/experts/zh_cn/operation/mindspore_op_overload.py) [![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.7/tutorials/experts/source_zh_cn/operation/op_overload.ipynb)\n", "\n", "## Overview\n", "\n", "`mindspore.ops.composite` provides composite operators that involve graph transformation, such as `MultitypeFuncGraph` and `HyperMap`.\n", "\n", "## MultitypeFuncGraph\n", "\n", "`MultitypeFuncGraph` generates overloaded functions that accept inputs of different types. With it you can define a group of overloaded functions, each with its own implementation, selected according to the input types. First create a `MultitypeFuncGraph` object, then decorate each function to be registered with `register`, passing the input types it handles, so that the object supports multiple input types. For more usage details, see [MultitypeFuncGraph](https://www.mindspore.cn/docs/zh-CN/r1.7/api_python/ops/mindspore.ops.MultitypeFuncGraph.html).\n", "\n", "A code sample follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-12-31T08:28:36.365991Z", "start_time": "2021-12-31T08:28:34.387953Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor [[2.4 
4.2]\n", " [4.4 6.4]]\n", "scalar 3\n" ] } ], "source": [ "import numpy as np\n", "from mindspore.ops import MultitypeFuncGraph\n", "from mindspore import Tensor\n", "import mindspore.ops as ops\n", "\n", "# Create the overloaded function object.\n", "add = MultitypeFuncGraph('add')\n", "\n", "# Overload used when both inputs are scalars.\n", "@add.register(\"Number\", \"Number\")\n", "def add_scalar(x, y):\n", "    return ops.scalar_add(x, y)\n", "\n", "# Overload used when both inputs are tensors.\n", "@add.register(\"Tensor\", \"Tensor\")\n", "def add_tensor(x, y):\n", "    return ops.tensor_add(x, y)\n", "\n", "tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))\n", "tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))\n", "print('tensor', add(tensor1, tensor2))\n", "print('scalar', add(1, 2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## HyperMap\n", "\n", "`HyperMap` applies a specified operation to one or more groups of inputs and can be combined with `MultitypeFuncGraph`. For example, after defining a group of overloaded `add` functions, you can apply `add` to several groups of inputs of different types. Unlike `Map`, `HyperMap` also works on nested structures: it applies the operation to the inputs in a sequence or a nested sequence. For more usage details, see [HyperMap](https://www.mindspore.cn/docs/zh-CN/r1.7/api_python/ops/mindspore.ops.HyperMap.html).\n", "\n", "A code sample follows:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-12-31T08:28:36.379981Z", "start_time": "2021-12-31T08:28:36.368396Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output = (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 6), 3)\n" ] } ], "source": [ "from mindspore import dtype as mstype\n", "from mindspore import Tensor\n", "from mindspore.ops import MultitypeFuncGraph, HyperMap\n", "import mindspore.ops as ops\n", "\n", "# Define the same overloaded add function as in the previous example.\n", "add = MultitypeFuncGraph('add')\n", "\n", "@add.register(\"Number\", \"Number\")\n", "def add_scalar(x, y):\n", "    return ops.scalar_add(x, y)\n", "\n", "@add.register(\"Tensor\", \"Tensor\")\n", "def add_tensor(x, y):\n", "    return ops.tensor_add(x, y)\n", "\n", "# Apply add element by element across the two input sequences.\n", "add_map = HyperMap(add)\n", "output = add_map((Tensor(1, mstype.float32), Tensor(2, mstype.float32), 1), (Tensor(3, mstype.float32), Tensor(4, mstype.float32), 
2))\n", "print(\"output =\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, the input passed to `add_map` consists of two sequences. `HyperMap` takes the corresponding elements from the two sequences in the form `operation(args[0][i], args[1][i])` and passes them to `add` as `x` and `y`, for example `add(Tensor(1, mstype.float32), Tensor(3, mstype.float32))`. Each pair is dispatched by type, so the two tensor pairs are handled by `add_tensor` and the scalar pair `(1, 2)` by `add_scalar`." ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }