{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 使用流程控制语句\n", "\n", "`Ascend` `GPU` `CPU` `模型开发`\n", "\n", "[![在线运行](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX2NvbnRyb2xfZmxvdy5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![下载Notebook](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_control_flow.ipynb) [![下载样例代码](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_control_flow.py) [![查看源文件](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.6/docs/mindspore/programming_guide/source_zh_cn/control_flow.ipynb)\n", "\n", "## 概述\n", "\n", "MindSpore流程控制语句的使用与Python原生语法相似,尤其是在`PYNATIVE_MODE`模式下, 与Python原生语法基本一致,但是在`GRAPH_MODE`模式下,会有一些特殊的约束。鉴于以上原因,下述关于流程控制语句的使用指导均是指在`GRAPH_MODE`模式下运行。\n", "\n", "使用流程控制语句时,MindSpore会依据条件是否为变量来决策是否在网络中生成控制流算子,只有条件为变量时网络中才会生成控制流算子。如果条件表达式的结果需要在图编译时确定,则条件为常量,否则条件为变量。需要特殊说明的是,当网络中存在控制流算子时,网络会被切分成多个执行子图,子图间进行流程跳转和数据传递会引起一定的性能损耗。\n", "\n", "条件为变量的场景:\n", "\n", "- 条件表达式中存在Tensor或者元素为Tensor类型的List、Tuple、Dict,并且条件表达式的结果受Tensor的值影响。\n", "\n", "常见的变量条件:\n", "\n", "- `(x < y).all()`,`x`或`y`为算子输出。此时条件是否为真取决于算子输出`x`、`y`,而算子输出是图在各个step执行时才能确定。\n", "\n", "- `x in list`,`x`为算子输出。\n", "\n", "条件为常量的场景:\n", "\n", "- 条件表达式中不存在Tensor和元素为Tensor类型的List、Tuple、Dict。\n", "\n", "- 条件表达式中存在Tensor或者元素为Tensor类型的List、Tuple、Dict,但是表达式结果不受Tensor的值影响。\n", "\n", "常见的常量条件:\n", "\n", "- `self.flag`,`self.flag`为标量。此处`self.flag`为一个bool类型标量,其值在构建Cell对象时已确定,因此该条件是一个常量条件。\n", "\n", "- `x + 1 < 10`,`x`为标量。此处`x + 1`的值在构建Cell对象时是不确定的,但是在图编译时MindSpore会计算所有标量表达式的结果,因此该表达式的值也是在编译期确定的,该条件为常量条件。\n", "\n", "- `len(my_list) < 10`,`my_list为`元素是Tensor类型的List对象。虽然该条件表达式包含Tensor,但是表达式结果不受Tensor的值影响,只与`my_list`中Tensor的数量有关,因此该条件为常量条件。\n", "\n", "- `for i in range(0,10)`,`i`为标量,潜在的条件表达式`i < 10`为常量条件。\n", "\n", "## 使用if语句\n", "\n", "使用`if`语句需要注意在条件为变量时,在不同分支中的同一变量名应被赋予相同的数据类型。同时,网络最终生成的执行图的子图数量与`if`的数量成正比关系,过多的`if`会产生较大的控制流算子性能开销和子图间数据传递性能开销。\n", "\n", "### 使用条件为变量的if语句\n", "\n", "在例1中,`out`在true分支被赋值为[0],在false分支被赋值为[0, 1],条件`x < y`为变量,因此在`out = out + 1`这一句无法确定输入`out`的数据类型,会导致图编译出现异常。\n", "\n", "例1:" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2021-12-30T02:02:29.064538Z", "start_time": "2021-12-30T02:02:25.061791Z" } }, "source": [ "```python\n", "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class SingleIfNet(nn.Cell):\n", " def construct(self, x, y, z):\n", " if x < y:\n", " out = x\n", " else:\n", " out = z\n", " out = out + 1\n", " return out\n", "\n", "forward_net = SingleIfNet()\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "z = Tensor(np.array([0, 1]), dtype=ms.int32)\n", "output = forward_net(x, y, z)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "例1报错信息如下:\n", "\n", "```text\n", "ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (2), shape2 = ()..\n", "```\n", "\n", "### 使用条件为常量的if语句\n", "\n", "在例2中,`out`在true分支被赋值为标量0,在false分支被赋值为[0, 1],`x`和`y`均为标量,条件`x < y + 1`为常量,图编译阶段可以确定是走true分支,因此网络中只存在true分支的内容并且无控制流算子,`out = out + 1`的输入`out`数据类型是确定的,因此该用例可正常执行。\n", "\n", "例2:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:40.804471Z", "start_time": "2022-01-07T03:35:39.226569Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output: 1\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class SingleIfNet(nn.Cell):\n", " def construct(self, z):\n", " x = 0\n", " y = 1\n", " if x < y + 1:\n", " out = x\n", " else:\n", " out = z\n", " out = out + 1\n", " return out\n", "\n", "forward_net = SingleIfNet()\n", "z = Tensor(np.array([0, 1]), dtype=ms.int32)\n", "output = forward_net(z)\n", "print(\"output:\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用for语句\n", "\n", "`for`语句会展开循环体内容。在例3中,`for`循环了3次,与例4最终生成的执行图结构是完全一致的,因此使用`for`语句的网络的子图数量、算子数量取决于`for`的迭代次数,算子数量过多或者子图过多会导致硬件资源受限。`for`语句导致出现子图过多的问题时,可参考`while`写作方式,尝试将`for`语句等价转换为条件是变量的`while`语句。\n", "\n", "例3:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:40.922574Z", "start_time": "2022-01-07T03:35:40.806485Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class IfInForNet(nn.Cell):\n", " def construct(self, x, y):\n", " out = 0\n", " for i in range(0, 3):\n", " if x + i < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", " return out\n", "\n", "forward_net = IfInForNet()\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y)\n", "print(\"output:\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "例4:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:40.985953Z", "start_time": "2022-01-07T03:35:40.923594Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class IfInForNet(nn.Cell):\n", " def construct(self, x, y):\n", " out = 0\n", " #######cycle 0\n", " if x + 0 < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", "\n", " #######cycle 1\n", " if x + 1 < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", "\n", " #######cycle 2\n", " if x + 2 < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", " return out\n", "\n", "forward_net = IfInForNet()\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y)\n", "print(\"output:\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用while语句\n", "\n", "`while`语句相比`for`语句更为灵活。当`while`的条件为常量时,`while`对循环体的处理和`for`类似,会展开循环体里的内容。当`while`的条件为变量时,`while`不会展开循环体例的内容,则会在执行图产生控制流算子。\n", "\n", "### 使用条件为常量的while语句\n", "\n", "如例5所示,条件`i < 3`为常量, `while`的循环体的内容会被复制3份,因此最终生成的执行图和例4完全一致。`while`语句条件为常量时,算子数量和子图数量与while的循环次数成正比,算子数量过多或者子图过多会导致硬件资源受限。\n", "\n", "例5:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:41.051756Z", "start_time": "2022-01-07T03:35:40.988065Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class IfInWhileNet(nn.Cell):\n", " def construct(self, x, y):\n", " i = 0\n", " out = x\n", " while i < 3:\n", " if x + i < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", " i = i + 1\n", " return out\n", "\n", "forward_net = IfInWhileNet()\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y)\n", "print(\"output:\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 使用条件为变量的while语句\n", "\n", "如例6所示,把`while`条件变更为变量,`while`按照不展开处理,最终网络输出结果和例5一致,但执行图的结构不一致。例6不展开的执行图,有较少的算子和较多的子图,会使用较短的编译时间和占用较小的设备内存,但是会产生额外的控制流算子执行和子图间数据传递引起的性能开销。\n", "\n", "例6:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:41.100239Z", "start_time": "2022-01-07T03:35:41.053312Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class IfInWhileNet(nn.Cell):\n", " def construct(self, x, y, i):\n", " out = x\n", " while i < 3:\n", " if x + i < y:\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", " i = i + 1\n", " return out\n", "\n", "forward_net = IfInWhileNet()\n", "i = Tensor(np.array(0), dtype=ms.int32)\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y, i)\n", "print(\"output:\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当`while`的条件为变量时,`while`循环体不能展开,`while`循环体内的表达式都是在各个step运行时计算,因此循环体内部不能出现标量、List、Tuple等非Tensor类型的计算操作,这些类型计算操作需要在图编译时期完成,与`while`在执行期进行计算的机制是矛盾的。如例7所示,条件`i < 3`是变量条件,但是循环体内部存在`j = j + 1`的标量计算操作,最终会导致图编译出错。\n", "\n", "例7:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore import dtype as ms\n", "\n", "class IfInWhileNet(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.nums = [1, 2, 3]\n", "\n", " def construct(self, x, y, i):\n", " j = 0\n", " out = x\n", " while i < 3:\n", " if x + i < y :\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + self.nums[j]\n", " i = i + 1\n", " j = j + 1\n", " return out\n", "\n", "forward_net = IfInWhileNet()\n", "i = Tensor(np.array(0), dtype=ms.int32)\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y, i)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "例7报错信息如下:\n", "\n", "```text\n", "IndexError: mindspore/core/abstract/prim_structures.cc:178 InferTupleOrListGetItem] list_getitem evaluator index should be in range[-3, 3), but got 3.\n", "```\n", "\n", "`while`条件为变量时,循环体内部不能更改算子的输入shape。因为MindSpore要求网络的同一个算子的输入shape在图编译时是确定的,而在`while`的循环体内部改变算子输入shape的操作是在图执行时生效,两者是矛盾的。如例8所示,条件`i < 3`为变量条件,`while`不展开,循环体内部的`ExpandDims`算子会改变表达式`out = out + 1`在下一轮循环的输入shape,会导致图编译出错。\n", "\n", "例8:\n", "\n", "```python\n", "import numpy as np\n", "from mindspore import Tensor, nn\n", "from mindspore.common import dtype as ms\n", "from mindspore import ops\n", "\n", "class IfInWhileNet(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.expand_dims = ops.ExpandDims()\n", "\n", " def construct(self, x, y, i):\n", " out = x\n", " while i < 3:\n", " if x + i < y :\n", " out = out + x\n", " else:\n", " out = out + y\n", " out = out + 1\n", " out = self.expand_dims(out, -1)\n", " i = i + 1\n", " return out\n", "\n", "forward_net = IfInWhileNet()\n", "i = Tensor(np.array(0), dtype=ms.int32)\n", "x = Tensor(np.array(0), dtype=ms.int32)\n", "y = Tensor(np.array(1), dtype=ms.int32)\n", "output = forward_net(x, y, i)\n", "```\n", "\n", "例8报错信息如下:\n", "\n", "```text\n", "ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (1, 1), shape2 = (1)..\n", "```\n", "\n", "## 使用while语句等价替换for语句\n", "\n", "`for`语句会展开循环体内容,带来执行性能提升的同时,也会带来编译问题,如增加编译时间、函数调用深度溢出、不同的权重被共享等等。为了解决该类问题,需要将`for`语句等价转换为`while`语句。\n", "\n", "### 简单示例\n", "\n", "如例9所示,实现了通过加法来完成两数相乘的计算。\n", "\n", "例9:\n", "\n", "```python\n", "from mindspore import Tensor\n", "from mindspore import ms_function\n", "\n", "one = Tensor(1)\n", "zero = Tensor(0)\n", "\n", "@ms_function\n", "def mul_by_for(x, y):\n", " r = zero\n", " for _ in range(y):\n", " r = r + x\n", " return r\n", "\n", "\n", "a = Tensor(2)\n", "b = 1000\n", "out = mul_by_for(a, b)\n", "print(out)\n", "```\n", "\n", "执行出错,报错信息如下所示:\n", "\n", "```text\n", "RuntimeError: mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc:201 Eval] Exceed function call depth limit 1000, (function call depth: 1001, simulate call depth: 998).\n", "It's always happened with complex construction of code or infinite recursion or loop.\n", "Please check the code if it's has the infinite recursion or call 'context.set_context(max_call_depth=value)' to adjust this value.\n", "If max_call_depth is set larger, the system max stack depth should be set larger too to avoid stack overflow.\n", "```\n", "\n", "这是因为`for`语句内的循环体会被展开。该例子中循环了1000次,展开后的子图过多,函数调用深度超出了允许的限制。为了不让子图数量展开太多,可等价替换成`while`实现,实现方式如例10所示。\n", "\n", "例10:\n", "\n", "```python\n", "from mindspore import Tensor\n", "from mindspore import ms_function\n", "\n", "one = Tensor(1)\n", "zero = Tensor(0)\n", "\n", "@ms_function\n", "def mul_by_while(x, y):\n", " y = Tensor(y)\n", " r = zero\n", " while y > 0:\n", " y = y - one\n", " r = r + x\n", " return r\n", "\n", "a = Tensor(2)\n", "b = 1000\n", "out = mul_by_while(a, b)\n", "print(out)\n", "```\n", "\n", "执行结果正确,如下:\n", "\n", "```text\n", "2000\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 循环体内有权重\n", "\n", "如例11所示,实现了`1+2+3`的计算,同时用权重保存循环中每次迭代的计数器的值。\n", "\n", "例11:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:41.171702Z", "start_time": "2022-01-07T03:35:41.101212Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.0\n" ] } ], "source": [ "import mindspore\n", "from mindspore import nn, Tensor\n", "from mindspore import Parameter\n", "\n", "class AddIndexNet(nn.Cell):\n", " def __init__(self, index):\n", " super(AddIndexNet, self).__init__()\n", " self.weight = Parameter(Tensor(0, mindspore.float32), name=\"weight\")\n", " self.idx = Tensor(index)\n", "\n", " def construct(self, x):\n", " self.weight = self.weight + self.idx\n", " x = x + self.weight\n", " return x\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.idx = Tensor(0)\n", " self.block_nums = 3\n", " self.nets = []\n", " for i in range(self.block_nums):\n", " self.nets.append(AddIndexNet(i + 1))\n", "\n", " def construct(self, x):\n", " for i in range(self.block_nums):\n", " x = self.nets[i](x)\n", " return x\n", "\n", "x = Tensor(0, mindspore.float32)\n", "net = Net()\n", "out = net(x)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "结果和预期不符,预期结果应该是`1.0+2.0+3.0=6.0`。这是因为循环体展开后,不同迭代中的算子共享了同一个权重(即self.weight),导致每次迭代,更新的都是同一个权重。\n", "\n", "为了解决该问题,我们使用`while`语句等价替换该`for`语句,如例12所示。\n", "\n", "例12:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2022-01-07T03:35:41.220307Z", "start_time": "2022-01-07T03:35:41.172728Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[6.]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore\n", "from mindspore import nn, Tensor, ops\n", "from mindspore import Parameter\n", "\n", "class AddIndexNet(nn.Cell):\n", " def __init__(self, block_nums):\n", " super(AddIndexNet, self).__init__()\n", " self.weights = Parameter(Tensor(np.zeros((block_nums, 1)), mindspore.float32), name=\"weights\")\n", " self.gather = ops.Gather()\n", "\n", " def construct(self, x, index):\n", " weight = self.gather(self.weights, index, 0)\n", " weight += (index + 1)\n", " x = x + weight\n", " return x\n", "\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.idx = Parameter(Tensor(0), name=\"index\")\n", " self.iter_num = 3\n", " self.add_net = AddIndexNet(self.iter_num)\n", "\n", " def construct(self, x):\n", " while self.idx < self.iter_num:\n", " x = self.add_net(x, self.idx)\n", " self.idx += 1\n", " return x\n", "\n", "\n", "x = Tensor([0], mindspore.float32)\n", "net = Net()\n", "out = net(x)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在该例中,拓展了权重的维度为`[iter_num, 1]`,不同迭代中即使共享同一个权重,但第0维之外的数据相对独立,使用时再通过`Gather`算子取出对应的数据,完成相应的计算。预期的结果是`1+2+3=6`,本例中打印的结果值符合预期。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 约束\n", "\n", "当前使用流程语句除了条件变量场景下的约束,还有一些其他特定场景下的约束。\n", "\n", "### 动态shape约束\n", "\n", "当网络模型中存在动态shape时,禁止使用条件为变量的流程控制语句,否则可能会出现未知异常。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }