{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Calling Custom Classes\n", "\n", "[![Run Online](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMS43L3R1dG9yaWFscy9leHBlcnRzL3poX2NuL2RlYnVnL21pbmRzcG9yZV9tc19jbGFzcy5pcHluYg==&imageid=9d63f4d1-dc09-4873-b669-3483cea777c0) [![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.7/tutorials/experts/zh_cn/debug/mindspore_ms_class.ipynb) [![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.7/tutorials/experts/zh_cn/debug/mindspore_ms_class.py) [![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.7/tutorials/experts/source_zh_cn/debug/ms_class.ipynb)\n", "\n", "## Overview\n", "\n", "With ms_class, users can access the attributes and methods of custom classes in static graph mode.\n", "\n", "In static graph mode, when the attributes or methods of a custom class are needed, decorate the class with @ms_class so that they can be accessed. ms_class is also supported in dynamic graph mode (PyNative mode), but there the attributes and methods of a custom class can be accessed without the @ms_class decorator.\n", "\n", "This document describes the usage scenarios and usage notes of ms_class so that you can use it more effectively.\n", "\n", "## Usage Scenarios\n", "\n", "### 1. Accessing Attributes of a Custom Class\n", "\n", "To access the attributes of a custom class, decorate the class with @ms_class." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 2 3]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_class\n", "\n", 
"@ms_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.value = Tensor(np.array([1, 2, 3]))\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self):\n", "        out = self.inner_net.value\n", "        return out\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Calling Methods of a Custom Class\n", "\n", "To call the methods of a custom class, decorate the class with @ms_class." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.498145Z", "start_time": "2022-01-04T11:36:31.408221Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5 7 9]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_class\n", "\n", "@ms_class\n", "class InnerNet:\n", "    def act(self, x, y):\n", "        return x + y\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self, x, y):\n", "        out = self.inner_net.act(x, y)\n", "        return out\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "x = Tensor(np.array([1, 2, 3]).astype(np.int32))\n", "y = Tensor(np.array([4, 5, 6]).astype(np.int32))\n", "net = Net()\n", "out = net(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Accessing Attributes and Methods of Nested Custom Classes\n", "\n", "When custom classes are nested, the attributes and methods of the nested classes can be accessed as long as every class is decorated with @ms_class." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.517439Z", "start_time": "2022-01-04T11:36:31.499697Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 2 3]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_class\n", "\n", "@ms_class\n", "class 
Inner:\n", "    def __init__(self):\n", "        self.value = Tensor(np.array([1, 2, 3]))\n", "\n", "@ms_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.inner = Inner()\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self):\n", "        out = self.inner_net.inner.value\n", "        return out\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Nesting Custom Classes with nn.Cell\n", "\n", "The attributes and methods of a custom class can also be accessed when the custom class is nested with nn.Cell. For an introduction to nn.Cell, see [mindspore.nn.Cell](https://www.mindspore.cn/docs/zh-CN/r1.7/api_python/nn/mindspore.nn.Cell.html)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.517439Z", "start_time": "2022-01-04T11:36:31.499697Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import dtype as mstype\n", "from mindspore import context, Tensor, ms_class\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self, val):\n", "        super().__init__()\n", "        self.val = val\n", "\n", "    def construct(self, x):\n", "        return x + self.val\n", "\n", "@ms_class\n", "class TrainNet():\n", "    class Loss(nn.Cell):\n", "        def __init__(self, net):\n", "            super().__init__()\n", "            self.net = net\n", "\n", "        def construct(self, x):\n", "            out = self.net(x)\n", "            return out * 2\n", "\n", "    def __init__(self, net):\n", "        self.net = net\n", "        loss_net = self.Loss(self.net)\n", "        self.number = loss_net(10)\n", "\n", "global_net = Net(1)\n", "class LearnNet(nn.Cell):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.value = TrainNet(global_net).number\n", "\n", "    def construct(self, x):\n", "        return x + self.value\n", "\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "x = Tensor(3, 
mstype.int32)\n", "learn_net = LearnNet()\n", "out = learn_net(x)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage Notes\n", "\n", "Keep the following restrictions in mind when using ms_class:\n", "\n", "### 1. ms_class does not support non-class types\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "\n", "from mindspore import ms_class\n", "\n", "@ms_class\n", "def func(x, y):\n", "    return x + y\n", "\n", "func(1, 2)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running this code reports the following error:\n", "\n", "TypeError: Decorator ms_class can only be used for class type, but got .\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. ms_class supports accessing the attributes and methods of class instances. It does not support accessing them directly from the class definition, or creating instances of a custom class inside construct or ms_function functions.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "\n", "import mindspore.nn as nn\n", "from mindspore import context, ms_class\n", "\n", "@ms_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.number = 1\n", "\n", "class Net(nn.Cell):\n", "    def construct(self):\n", "        out = InnerNet().number\n", "        return out\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "net = Net()\n", "net()\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running this code reports the following error:\n", "\n", "ValueError: This may be not defined, or it can't be a operator. 
Please check code.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. The @ms_class decorator cannot be applied to nn.Cell.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_class\n", "\n", "@ms_class\n", "class Net(nn.Cell):\n", "    def construct(self, x):\n", "        return x\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "x = Tensor(1)\n", "net = Net()\n", "net(x)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running this code reports the following error:\n", "\n", "TypeError: ms_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Private attributes and magic methods of a custom class cannot be accessed.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "\n", "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_class\n", "\n", "@ms_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.value = Tensor(np.array([1, 2, 3]))\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self):\n", "        out = self.inner_net.__str__()\n", "        return out\n", "\n", "context.set_context(mode=context.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running this code reports the following error:\n", "\n", "AttributeError: `__str__` is a private variable or magic method, which is not supported.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }