# Compiling Performance Optimization for Static Graph Network

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_source_en.png)](https://gitee.com/mindspore/docs/blob/r2.0/tutorials/experts/source_en/network/op_overload.md)

## Overview

In deep learning network training or inference, the end-to-end time consumption of a network is essentially composed of compiling time and running time. Especially in inference scenarios, the compiling time is often much larger than the running time, so optimizing compiling performance is of great importance for improving the deployment effect of a network in real applications.

In static graph mode, some scenarios can optimize network compiling performance by changing the network writing style, using equivalent semantic substitution, or setting compiling options to change the compiling mechanism.

## Optimizing Compiling Performance with HyperMap

`HyperMap` is a special class. Constructing the class object requires passing in a mapping function f, and calling the object requires passing in the n argument sequences of f; refer to [HyperMap](https://www.mindspore.cn/docs/en/r2.0/api_python/ops/mindspore.ops.HyperMap.html) for more details. The mapping function f must be of type `MultitypeFuncGraph`; refer to [MultitypeFuncGraph](https://www.mindspore.cn/docs/en/r2.0/api_python/ops/mindspore.ops.MultitypeFuncGraph.html) for more details. When batch processing list elements with a for loop, network compiling performance can be optimized by an equivalent semantic substitution with `HyperMap`.

A sample code that uses `HyperMap` instead of a for loop to optimize compiling performance is as follows:

```python
import time
from mindspore.ops import MultitypeFuncGraph, HyperMap
from mindspore import ops, Tensor
import mindspore as ms

add = MultitypeFuncGraph('add')
@add.register("Tensor", "Tensor")
def add_Tensor(x, y):
    return ops.add(x, y)

add_map = HyperMap(add)
list1 = [Tensor(i) for i in range(200)]
list2 = [Tensor(i) for i in range(200)]

@ms.jit
def hyper_map_net():
    output = add_map(list1, list2)
    return output

start_time = time.time()
output = hyper_map_net()
end_time = time.time()
print("hyper map cost time:", end_time - start_time)

@ms.jit
def for_loop_net():
    out = []
    for i in range(200):
        out.append(i + i)
    return out

start_time = time.time()
for_loop_net()
end_time = time.time()
print("for loop cost time:", end_time - start_time)
```

```text
hyper map cost time: 0.1894233226776123
for loop cost time: 1.2634551525115967
```
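Because the mapping function is a `MultitypeFuncGraph`, one `HyperMap` object can dispatch on the types of the elements it receives. Below is a minimal sketch, not part of the original sample, that registers an extra overload for Python numbers alongside the Tensor overload:

```python
from mindspore.ops import MultitypeFuncGraph, HyperMap
from mindspore import ops, Tensor

add = MultitypeFuncGraph('add')

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.add(x, y)

@add.register("Number", "Number")
def add_number(x, y):
    # Plain Python addition; this overload is selected when
    # both elements are scalars.
    return x + y

add_map = HyperMap(add)
print(add_map([1, 2, 3], [4, 5, 6]))      # elementwise scalar sums: 5, 7, 9
print(add_map([Tensor(1)], [Tensor(2)]))  # dispatches to the Tensor overload
```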
## Optimizing Compiling Performance with Select Operator

When writing a network, you will often use if statements. If the condition of an if statement is a variable condition, each if statement generates additional subgraphs. The use of if statements can be found at: [if statements](https://mindspore.cn/tutorials/experts/en/r2.0/network/control_flow.html#if-statement). In static graph mode, the more subgraphs there are, the longer the compiling time, so some scenarios can optimize compiling performance by equivalently replacing the if statement with the `Select` operator.

It should be noted that replacing an if statement with the `Select` operator also affects the running performance of the network. On the one hand, the `Select` operator executes both the true branch and the false branch, while the if statement executes only one of them, so in this respect the if statement runs faster than the `Select` operator. On the other hand, the `Select` operator itself performs better than the control flow operators generated by the if statement, so in this respect the if statement runs slower than the `Select` operator. Combining these two factors, the final running performance change needs to be judged case by case. In general, when the number of operators in a branch is small, it is recommended to use the `Select` operator; when the number of operators in a branch is large, it is recommended to use the if statement.

A sample code that uses the `Select` operator instead of the if statement to optimize compiling performance is as follows:

```python
import time
from mindspore import ops
import mindspore as ms

@ms.jit
def if_net(x, y):
    out = 0
    for _ in range(100):
        if x < y:
            x = x - y
        else:
            x = x + y
        out = out + x
    return out

start_time = time.time()
out = if_net(ms.Tensor([0]), ms.Tensor([1]))
end_time = time.time()
print("if net cost time:", end_time - start_time)

@ms.jit
def select_net(x, y):
    out = x
    for _ in range(100):
        cond = x < y
        x = ops.Select()(cond, x - y, x + y)
        out = out + x
    return out

start_time = time.time()
out = select_net(ms.Tensor([0]), ms.Tensor([1]))
end_time = time.time()
print("select net cost time:", end_time - start_time)
```

```text
if net cost time: 1.1603329181671143
select net cost time: 0.483151912689209
```

## Optimizing Compiling Performance with Compiling Cache

If no changes are made to the network structure during training or inference, the compiling time can be reduced by using the compiling cache. The essence of the compiling cache is to store the intermediate compiling files of the network model: when the network model remains unchanged, the intermediate compiling files produced are also the same, so the intermediate files produced by the previous compilation can be reused. Refer to `enable_compile_cache` in [set_context](https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore/mindspore.set_context.html) for more details.

A sample code with the compiling cache disabled is as follows:

```python
import time
from mindspore import set_context
from mindspore import dtype
import mindspore as ms

@ms.jit
def func(input_x, input_y):
    output = input_x
    for _ in range(200):
        output = input_x + input_x * input_y + output
    return output

set_context(enable_compile_cache=False)
x = ms.Tensor([1], dtype.float32)
y = ms.Tensor([2], dtype.float32)
start_time = time.time()
out = func(x, y)
end_time = time.time()
print("Disable compile_cache cost time:", end_time - start_time)
```

```text
Disable compile_cache cost time: 0.5485098361968994
```

The above test sample runs with the compiling cache turned off. Executing it twice gives the following time consumption for the first and the second run (the actual time consumption depends on the hardware environment; the following data is for reference only):

```text
Disable compile_cache cost time: 0.5485098361968994
Disable compile_cache cost time: 0.4614279270172119
```

When the compiling cache is turned off, the time consumption of the first and the second run is basically the same.
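The cached files can only be reused when the network remains unchanged. When repeating the timing experiment below, you may also want to start from a cold cache so that the first run measures a full compilation. A minimal sketch, not part of the original sample, assuming the cache is written to the `my_compile_cache` directory used below:

```python
import shutil

# Delete any compile cache left over from a previous run so that the
# first timed execution below measures a full compilation.
shutil.rmtree("my_compile_cache", ignore_errors=True)
```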
A sample code to optimize compiling performance by enabling the compiling cache is as follows:

```python
import time
from mindspore import set_context
from mindspore import dtype
import mindspore as ms

@ms.jit
def func(input_x, input_y):
    output = input_x
    for _ in range(200):
        output = input_x + input_x * input_y + output
    return output

set_context(enable_compile_cache=True, compile_cache_path="my_compile_cache")
x = ms.Tensor([1], dtype.float32)
y = ms.Tensor([2], dtype.float32)
start_time = time.time()
out = func(x, y)
end_time = time.time()
print("Enable compile_cache cost time:", end_time - start_time)
```

```text
Enable compile_cache cost time: 0.6357541084289551
```

The above test sample runs with the compiling cache turned on. Executing it twice gives the following time consumption for the first and the second run (the actual time consumption depends on the hardware environment; the following data is for reference only):

```text
Enable compile_cache cost time: 0.6357541084289551
Enable compile_cache cost time: 0.09379792213439941
```

When the compiling cache is turned on, the second execution takes about 1/7 of the time of the first execution.

## Optimizing Compiling Performance with vmap

MindSpore currently supports vmap, which can be used instead of a for loop to optimize compiling performance when batches of data without dependencies are processed and the related operators support vmap. For a detailed introduction, refer to [vmap](https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore/mindspore.vmap.html). It should be noted that vmap optimizes not only compiling performance, but also running performance.

A sample code that uses vmap instead of a for loop to process batch data and optimize compiling performance is as follows:

```python
import numpy as np
import time
from mindspore import ops, vmap
import mindspore as ms

def hswish_func(x):
    return ops.HSwish()(x)

@ms.jit
def manually_batched(xs):
    output = []
    for i in range(xs.shape[0]):
        output.append(hswish_func(xs[i]))
    return ops.stack(output)

shape = (100, 2)
prop = 100
x_np = (np.random.randn(*shape) * prop).astype(np.float32)
x = ms.Tensor(x_np)
x = ops.sub(x, 0)

start_time = time.time()
output_vmap = vmap(hswish_func, in_axes=(0,))(x)
end_time = time.time()
print("vmap cost time:", end_time - start_time)

start_time = time.time()
output_manually = manually_batched(x)
end_time = time.time()
print("for loop cost time:", end_time - start_time)
```

```text
vmap cost time: 0.05766916275024414
for loop cost time: 1.9284062385559082
```

The above sample is equivalent to batch processing 100 groups of Tensor data, and the performance of processing with vmap is more than 30 times that of processing with a for loop.
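vmap also handles functions with multiple inputs by giving one `in_axes` entry per argument. Below is a minimal sketch, not part of the original tutorial and using a hypothetical `scaled_add` helper, that batches a two-argument function over the leading axis of both inputs:

```python
import numpy as np
from mindspore import ops, vmap
import mindspore as ms

def scaled_add(x, y):
    # A simple elementwise function of two arguments.
    return ops.add(x, y) * 0.5

xs = ms.Tensor(np.random.randn(100, 2).astype(np.float32))
ys = ms.Tensor(np.random.randn(100, 2).astype(np.float32))

# Map scaled_add over axis 0 of both inputs in a single call,
# replacing an explicit Python loop over the 100 rows.
output = vmap(scaled_add, in_axes=(0, 0))(xs, ys)
print(output.shape)  # (100, 2)
```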