# Using the Process Control Statement

`Ascend` `GPU` `CPU` `Model Development`

[![View Source On Gitee](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_source_en.png)](https://gitee.com/mindspore/docs/blob/r1.5/docs/mindspore/programming_guide/source_en/control_flow.md)

## Overview

MindSpore process control statements are similar to native Python syntax, especially in `PYNATIVE_MODE`. However, there are some special constraints in `GRAPH_MODE`. The process control statements below are executed in `GRAPH_MODE`.

When a process control statement is used, MindSpore decides whether to generate a control flow operator on the network based on whether the condition is a variable: a control flow operator is generated only when the condition is a variable. If the result of a condition expression can be determined during graph build, the condition is a constant; otherwise, it is a variable. Note that when a control flow operator exists in a network, the network is divided into multiple execution subgraphs, and the process jumping and data transmission between subgraphs cause some performance loss.

Scenarios in which the condition is a variable:

- The condition expression contains tensors or a list, tuple, or dict of the tensor type, and the condition expression result is affected by the tensor values. Common variable conditions are as follows:

    - `(x < y).all()`, where `x` or `y` is an operator output. Whether the condition is true depends on the operator outputs `x` and `y`, which are known only when each step is executed.

    - `x in list`, where `x` is an operator output.

Scenarios in which the condition is a constant:

- The condition expression does not contain tensors or a list, tuple, or dict of the tensor type.

- The condition expression contains tensors or a list, tuple, or dict of the tensor type, but the condition expression result is not affected by the tensor values. Common constant conditions are as follows:

    - `self.flag`, a scalar of the Boolean type. The value of `self.flag` is determined when the cell object is created, so `self.flag` is a constant condition.

    - `x + 1 < 10`, where `x` is a scalar. Although the value of `x + 1` is unknown when the cell object is created, MindSpore computes the results of all scalar expressions during graph build. Therefore, the expression value is determined during build, and this is a constant condition.

    - `len(my_list) < 10`, where `my_list` is a list of tensors. Although the condition expression contains tensors, its result is not affected by the tensor values and depends only on the number of tensors in `my_list`, so this is a constant condition.

    - `for i in range(0, 10)`, where `i` is a scalar and the implicit condition expression `i < 10` is a constant condition.

## Using the if Statement

When using the `if` statement with a variable condition, ensure that the same variable name is assigned the same data type in different branches. In addition, the number of subgraphs in the execution graph generated for the network is in direct proportion to the number of `if` statements; too many `if` statements incur high performance overheads from the control flow operators and from the data transmission between subgraphs.
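Before the failure cases below, the following minimal sketch (an illustration of the rule above, not one of the numbered examples; the class name `WellFormedIfNet` is ours) shows a variable-condition `if` in which both branches assign `out` a value of the same type and shape, so the graph builds successfully and a control flow operator is generated:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class WellFormedIfNet(nn.Cell):
    def construct(self, x, y):
        # `x < y` is a variable condition, but both branches assign `out`
        # a scalar int32 tensor, so the branch results join successfully.
        if x < y:
            out = x + y
        else:
            out = x - y
        out = out + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = WellFormedIfNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
```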
### Using an if Statement with a Variable Condition

In example 1, `out` is set to the scalar 0 in the true branch and to the tensor `[0, 1]` in the false branch. Because `x < y` is a variable condition, the type of `out` cannot be determined in the `out = out + 1` statement, causing a graph build exception.

Example 1:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class SingleIfNet(nn.Cell):
    def construct(self, x, y, z):
        if x < y:
            out = x
        else:
            out = z
        out = out + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = SingleIfNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
z = Tensor(np.array([0, 1]), dtype=ms.int32)
output = forward_net(x, y, z)
```

The error information in example 1 is as follows:

```text
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (2), shape2 = ()..
```

### Using an if Statement with a Constant Condition

In example 2, `out` is set to the scalar 0 in the true branch and to the tensor `[0, 1]` in the false branch. `x` and `y` are scalars, so `x < y + 1` is a constant condition, and it can be determined during the build phase that the true branch is taken. Therefore, only the content of the true branch exists on the network, there is no control flow operator, and the input type of `out` in `out = out + 1` is fixed. The test case can be executed properly.

Example 2:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class SingleIfNet(nn.Cell):
    def construct(self, z):
        x = 0
        y = 1
        if x < y + 1:
            out = x
        else:
            out = z
        out = out + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = SingleIfNet()
z = Tensor(np.array([0, 1]), dtype=ms.int32)
output = forward_net(z)
```

## Using the for Statement

The `for` statement expands the loop body. In example 3, the `for` loop runs three times, so the structure of the generated execution graph is the same as that in example 4. Therefore, the number of subgraphs and operators in a network that uses the `for` statement depends on the number of `for` iterations, and too many operators or subgraphs consume excessive hardware resources. If the `for` statement produces too many subgraphs, you can try to rewrite it as a `while` statement whose condition is a variable, following the `while` writing mode described below.

Example 3:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInForNet(nn.Cell):
    def construct(self, x, y):
        out = 0
        for i in range(0, 3):
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInForNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
```

Example 4:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInForNet(nn.Cell):
    def construct(self, x, y):
        out = 0
        ####### cycle 0
        if x + 0 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        ####### cycle 1
        if x + 1 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        ####### cycle 2
        if x + 2 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInForNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
```

## Using the while Statement

The `while` statement is more flexible than the `for` statement. When the condition of `while` is a constant, `while` processes and expands the loop body in a similar way as `for`.
When the condition of `while` is a variable, `while` does not expand the loop body; instead, a control flow operator is generated and evaluated during graph execution.

### Using a while Statement with a Constant Condition

In example 5, the condition `i < 3` is a constant, and the content of the `while` loop body is copied three times, so the generated execution graph is the same as that in example 4. When the `while` condition is a constant, the number of operators and subgraphs is proportional to the number of `while` iterations, and too many operators or subgraphs consume excessive hardware resources.

Example 5:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInWhileNet(nn.Cell):
    def construct(self, x, y):
        i = 0
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInWhileNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
```

### Using a while Statement with a Variable Condition

In example 6, the `while` condition is changed to a variable, so the `while` loop is not expanded. The final network output is the same as that in example 5, but the structure of the execution graph is different. The unexpanded execution graph in example 6 has fewer operators but more subgraphs, which takes less build time and device memory but incurs the extra performance overheads of executing control flow operators and transferring data between subgraphs.

Example 6:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInWhileNet(nn.Cell):
    def construct(self, x, y, i):
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
```

When the condition of `while` is a variable, the `while` loop body cannot be expanded, and the expressions in the loop body are calculated when each step runs. Therefore, computations of types other than tensor, such as scalar, list, and tuple operations, cannot exist in the loop body: such computations must be completed during graph build, which conflicts with the way the `while` loop body is computed at execution time. In example 7, the condition `i < 3` is a variable condition, but the loop body contains the scalar computation `j = j + 1`, so an error occurs during graph build.

Example 7:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInWhileNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.nums = [1, 2, 3]

    def construct(self, x, y, i):
        j = 0
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + self.nums[j]
            i = i + 1
            j = j + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
```

The error information in example 7 is as follows:

```text
IndexError: mindspore/core/abstract/prim_structures.cc:178 InferTupleOrListGetItem] list_getitem evaluator index should be in range[-3, 3), but got 3.
```
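If the trip count is known at build time, one possible workaround is to make the loop condition constant so that the scalar index is resolved during graph build. The following minimal sketch (our rewrite of example 7, not from the original document; the class name `IfInForNetFixed` is ours) uses a constant-condition `for` loop, so the loop is expanded during build and `self.nums[j]` becomes a build-time lookup:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class IfInForNetFixed(nn.Cell):
    def __init__(self):
        super().__init__()
        self.nums = [1, 2, 3]

    def construct(self, x, y):
        out = x
        # `range(3)` gives a constant condition: the loop is expanded during
        # graph build, so `j` is a build-time scalar and `self.nums[j]` is
        # resolved while the graph is built.
        for j in range(3):
            if x + j < y:
                out = out + x
            else:
                out = out + y
            out = out + self.nums[j]
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInForNetFixed()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
```

The trade-off is the one described in the `for` section above: the expanded loop produces more operators and subgraphs.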
When the `while` condition is a variable, the input shape of an operator cannot be changed in the loop body. MindSpore requires that the input shape of the same operator on the network be determined during graph build, whereas changing the input shape of an operator in the `while` loop body takes effect only during graph execution. In example 8, the condition `i < 3` is a variable condition, so `while` is not expanded, and the `ExpandDims` operator in the loop body changes the input shape of the expression `out = out + 1` in the next iteration. As a result, an error occurs during graph build.

Example 8:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore import ops

class IfInWhileNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.expand_dims = ops.ExpandDims()

    def construct(self, x, y, i):
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            out = self.expand_dims(out, -1)
            i = i + 1
        return out

context.set_context(mode=context.GRAPH_MODE)
forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
```

The error information in example 8 is as follows:

```text
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (1, 1), shape2 = (1)..
```

## Constraints

In addition to the constraints of the variable-condition scenarios above, the process control statements currently have constraints in other specific scenarios.

### Side Effects

When a process control statement with a variable condition is used, the network model generated after graph build contains a control flow operator, and the forward graph is executed twice. Therefore, if the forward graph contains side effect operators such as `Assign` in a training scenario, the computation result of the backward graph is inconsistent with the expected result. In example 9, the expected gradient of `x` is 2, but the actual gradient is 3. The reason is that the forward graph is executed twice, so `tmp = self.var + 1` and `self.assign(self.var, tmp)` are each executed twice, and `out = (self.var + 1) * x` actually becomes `out = (2 + 1) * x`. The gradient result is therefore incorrect.

Example 9:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore import ops
from mindspore.ops import composite
from mindspore import Parameter

class ForwardNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.var = Parameter(Tensor(np.array(0), ms.int32))
        self.assign = ops.Assign()

    def construct(self, x, y):
        if x < y:
            tmp = self.var + 1
            self.assign(self.var, tmp)
        out = (self.var + 1) * x
        out = out + 1
        return out

class BackwardNet(nn.Cell):
    def __init__(self, net):
        super(BackwardNet, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = composite.GradOperation()

    def construct(self, *inputs):
        grads = self.grad(self.forward_net)(*inputs)
        return grads

context.set_context(mode=context.GRAPH_MODE)
forward_net = ForwardNet()
backward_net = BackwardNet(forward_net)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = backward_net(x, y)
print("output:", output)
```

The execution result is as follows:

```text
output: 3
```
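A possible way to avoid this problem is to keep side effect operators out of the variable-condition branches, for example by threading the state through the network as an ordinary input instead of updating a `Parameter` with `Assign`. The following minimal sketch (our rewrite of example 9, not from the original document; the class name `ForwardNetNoAssign` is ours) is side-effect free, so executing the forward graph twice does not change any state, and the gradient of `x` is the expected 2:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore.ops import composite

class ForwardNetNoAssign(nn.Cell):
    def construct(self, x, y, var):
        if x < y:
            var = var + 1  # pure tensor computation instead of an Assign side effect
        out = (var + 1) * x
        out = out + 1
        return out

class BackwardNet(nn.Cell):
    def __init__(self, net):
        super(BackwardNet, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = composite.GradOperation()

    def construct(self, *inputs):
        # Gradient with respect to the first input, x.
        return self.grad(self.forward_net)(*inputs)

context.set_context(mode=context.GRAPH_MODE)
forward_net = ForwardNetNoAssign()
backward_net = BackwardNet(forward_net)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
var = Tensor(np.array(0), dtype=ms.int32)
output = backward_net(x, y, var)
print("output:", output)
```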
The following table lists the side effect operators that are not supported in the control flow training scenario.

| Side Effect Operators |
| --------------------- |
| Print                 |
| Assign                |
| AssignAdd             |
| AssignSub             |
| ScalarSummary         |
| ImageSummary          |
| TensorSummary         |
| HistogramSummary      |
| ScatterAdd            |
| ScatterDiv            |
| ScatterMax            |
| ScatterMin            |
| ScatterMul            |
| ScatterNdAdd          |
| ScatterNdSub          |
| ScatterNdUpdate       |
| ScatterNonAliasingAdd |
| ScatterSub            |
| ScatterUpdate         |

### Infinite Loop

If the value of `cond` in the expression `while cond:` is always the scalar `True`, an unexpected exception may be raised no matter whether there is a `break` or `return` statement in the `while` body.
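The following minimal sketch (our illustration of the constraint above, not from the original document; the class name `InfiniteLoopNet` is ours) shows the pattern to avoid: the loop condition is the constant scalar `True`, and termination depends only on a tensor-dependent `break`, so graph build may fail with an unexpected exception even though the loop terminates under Python semantics:

```python
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

class InfiniteLoopNet(nn.Cell):
    def construct(self, x):
        out = x
        # `True` is a constant scalar condition, so MindSpore treats this as a
        # loop to expand at build time; the tensor-dependent `break` cannot
        # bound the expansion, and an unexpected exception may be raised.
        while True:
            out = out + 1
            if out > 10:
                break
        return out

context.set_context(mode=context.GRAPH_MODE)
net = InfiniteLoopNet()
output = net(Tensor(np.array(0), dtype=ms.int32))
```

Where possible, express the termination condition directly in the `while` condition instead, for example `while out <= 10:`, so that it is handled as a variable condition.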