{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 参数传递\n", "\n", "`Ascend` `GPU` `CPU` `模型开发`\n", "\n", "[![在线运行](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX2luZGVmaW5pdGVfcGFyYW1ldGVyLmlweW5i&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![下载Notebook](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_indefinite_parameter.ipynb) [![下载样例代码](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_indefinite_parameter.py) [![查看源文件](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.6/docs/mindspore/programming_guide/source_zh_cn/indefinite_parameter.ipynb)\n", "\n", "## 概述\n", "\n", "本文介绍不定参数在网络构建中的使用,指的是在构建网络时,可以使用不定个数的参数来构造,有两种构造方式,一种是直接传入一个tuple类型的参数,另一种是传入Python的可变参数(\\*参数)来构造。下面以两个例子来说明这两种构造方式的用法。\n", "\n", "## 传入tuple类型的参数\n", "\n", "构造一个单Add算子网络,该网络需要有两个输入,可以传入一个tuple类型的参数来代替两个输入。网络构造如下:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:27:38.930014Z", "start_time": "2022-01-04T06:27:37.268097Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore\n", "from mindspore import Tensor\n", "from mindspore.nn import Cell\n", "import mindspore.ops as op\n", "\n", "\n", "class AddModel(Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.add = op.Add()\n", "\n", " def construct(self, inputs):\n", " return self.add(inputs[0], inputs[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`AddModel`网络的定义中,inputs表示的是一个tuple类型的参数,包含两个元素。调用方法如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:27:38.966557Z", "start_time": "2022-01-04T06:27:38.932034Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2. 2. 2.]\n", " [2. 2. 2.]]\n" ] } ], "source": [ "input_x = Tensor(np.ones((2, 3)), mindspore.float32)\n", "input_y = Tensor(np.ones((2, 3)), mindspore.float32)\n", "net = AddModel()\n", "\n", "y = net((input_x, input_y))\n", "print(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 传入Python的可变参数\n", "\n", "第二种用法是使用Python的可变参数(\\*参数),网络构造如下:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:27:38.973200Z", "start_time": "2022-01-04T06:27:38.968135Z" } }, "outputs": [], "source": [ "class AddModel(Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.add = op.Add()\n", "\n", " def construct(self, *inputs):\n", " return self.add(inputs[0], inputs[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "第二种用法,网络定义中,\\*inputs表示的是Python中的可变参数,可以在函数定义时收集位置参数组成一个tuple对象,也可以在函数调用时解包tuple对象中的每个参数,调用方法有两种,如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:27:39.005207Z", "start_time": "2022-01-04T06:27:38.974205Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2. 2. 2.]\n", " [2. 2. 2.]]\n" ] } ], "source": [ "input_x = Tensor(np.ones((2, 3)), mindspore.float32)\n", "input_y = Tensor(np.ones((2, 3)), mindspore.float32)\n", "net = AddModel()\n", "\n", "#1) The first call method\n", "y = net(input_x, input_y)\n", "\n", "#2) The second call method\n", "inputs = (input_x, input_y)\n", "y = net(*inputs)\n", "\n", "print(y)" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }