mindspore.set_algo_parameters

mindspore.set_algo_parameters(**kwargs)

Set the parameters of the algorithm that searches for parallel strategies. See test_auto_parallel_resnet.py for a typical use.

Note

Each parameter must be passed by keyword (its attribute name). This interface takes effect ONLY in AUTO_PARALLEL mode.

Parameters
  • fully_use_devices (bool) – Whether to search ONLY for strategies that fully use all available devices. Default: True. For example, with 8 devices available and this flag set to True, strategy (4, 1) is not included in ReLU’s candidate strategies, because it uses only 4 devices.

  • elementwise_op_strategy_follow (bool) – Whether an elementwise operator uses strategies consistent with those of its subsequent operators. Elementwise operators are operators that process their inputs element by element, such as Add and ReLU. Default: False. For example, for ReLU followed by Add, if this flag is set to True, the algorithm guarantees that the strategies it finds for the two operators are consistent, e.g., ReLU’s strategy (8, 1) and Add’s strategy ((8, 1), (8, 1)).

  • enable_algo_approxi (bool) – Whether to enable approximation in the search algorithm. Default: False. For large DNN models, the solution space of the parallel strategy search is large, so the search can take a long time. If this flag is set to True, the algorithm discards some candidate strategies to shrink the solution space.

  • algo_approxi_epsilon (float) – The epsilon value used by the approximation. Default: 0.1. It controls the degree of approximation: if an operator has S candidate strategies and ‘enable_algo_approxi’ is True, the number of remaining strategies is min{S, 1/epsilon}. With the default of 0.1, at most 10 strategies are kept per operator.

  • tensor_slice_align_enable (bool) – Whether to check the slice shapes of MatMul tensors. Default: False. On some hardware, MatMul kernels show an advantage only for large shapes. If this flag is True, the slice shapes of MatMul are checked to rule out irregular shapes.

  • tensor_slice_align_size (int) – The minimum slice size of MatMul tensors; the value must be in [1, 1024]. Default: 16. If ‘tensor_slice_align_enable’ is True, the slice size of the last dimension of MatMul tensors should be a multiple of this value.
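To make the filtering effects of fully_use_devices, algo_approxi_epsilon and tensor_slice_align_size concrete, here is a toy model in plain Python. It is not MindSpore's search algorithm: the enumeration (power-of-two split counts per dimension), the helper names candidate_strategies, approximate and is_aligned, and the choice of which strategies survive the approximation are all hypothetical assumptions made for illustration.

```python
from itertools import product
from typing import List, Tuple

Strategy = Tuple[int, ...]


def split_counts(num_devices: int) -> List[int]:
    """Powers of two that divide num_devices, e.g. 8 -> [1, 2, 4, 8]."""
    counts, c = [], 1
    while c <= num_devices and num_devices % c == 0:
        counts.append(c)
        c *= 2
    return counts


def candidate_strategies(shape: Tuple[int, ...], num_devices: int,
                         fully_use_devices: bool = True) -> List[Strategy]:
    """Toy enumeration of strategies (one split count per dimension) for one tensor."""
    result = []
    for strategy in product(split_counts(num_devices), repeat=len(shape)):
        used = 1
        for s in strategy:
            used *= s
        if num_devices % used != 0:
            continue  # this many slices cannot be mapped onto the devices
        if any(dim % s != 0 for dim, s in zip(shape, strategy)):
            continue  # every dimension must split evenly
        if fully_use_devices and used != num_devices:
            continue  # e.g. (4, 1) on 8 devices leaves 4 devices idle
        result.append(strategy)
    return result


def approximate(candidates: List[Strategy], epsilon: float) -> List[Strategy]:
    """Keep at most min(S, 1/epsilon) candidates; keeping the first ones is an assumption."""
    keep = min(len(candidates), int(1 / epsilon))
    return candidates[:keep]


def is_aligned(shape: Tuple[int, ...], strategy: Strategy, align_size: int) -> bool:
    """Check that the last-dimension slice is a multiple of align_size."""
    if not 1 <= align_size <= 1024:
        raise ValueError("align_size must be in [1, 1024]")
    last_slice = shape[-1] // strategy[-1]
    return last_slice % align_size == 0
```

For a (64, 64) ReLU input on 8 devices, candidate_strategies with fully_use_devices=True returns only (1, 8), (2, 4), (4, 2) and (8, 1); with False it returns 10 candidates including (4, 1), and approximate(..., 0.2) keeps 5 of them.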

Raises

ValueError – If a keyword argument is not recognized.

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend devices, users need to prepare the rank table and set rank_id and device_id. Please see the rank table startup for more details.

For GPU devices, users need to prepare the host file and MPI. Please see the mpirun startup for more details.

For the CPU device, users need to write a dynamic cluster startup script. Please see the Dynamic Cluster Startup for more details.

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
...                              search_mode="sharding_propagation")
>>> init()
>>> # Configure the strategy-search algorithm; each parameter is passed by keyword.
>>> ms.set_algo_parameters(fully_use_devices=True)
>>> ms.set_algo_parameters(elementwise_op_strategy_follow=True)
>>> ms.set_algo_parameters(enable_algo_approxi=True)
>>> ms.set_algo_parameters(algo_approxi_epsilon=0.2)
>>> ms.set_algo_parameters(tensor_slice_align_enable=True)
>>> ms.set_algo_parameters(tensor_slice_align_size=8)
>>>
>>> # Define the network structure.
>>> class Dense(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super().__init__()
...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
...         self.matmul = ops.MatMul()
...         self.add = ops.Add()
...
...     def construct(self, x):
...         x = self.matmul(x, self.weight)
...         x = self.add(x, self.bias)
...         return x
>>>
>>> class FFN(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.dense1 = Dense(28*28, 64)
...         self.relu = ops.ReLU()
...         self.dense2 = Dense(64, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.dense1(x)
...         x = self.relu(x)
...         x = self.dense2(x)
...         return x
>>> net = FFN()
>>> # Shard dense1's MatMul: split the input 2 ways along dim 0 and the weight 2 ways along dim 1.
>>> net.dense1.matmul.shard(((2, 1), (1, 2)))
>>>
>>> # Create dataset.
>>> step_per_epoch = 16
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> input_data = np.random.rand(1, 28, 28).astype(np.float32)
>>> label_data = np.random.rand(1).astype(np.int32)
>>> fake_dataset = get_dataset(input_data, label_data)
>>> dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
>>> # Train network.
>>> optimizer = nn.Momentum(net.trainable_params(), 1e-3, 0.1)
>>> loss_fn = nn.CrossEntropyLoss()
>>> loss_cb = train.LossMonitor()
>>> model = ms.Model(network=net, loss_fn=loss_fn, optimizer=optimizer)
>>> model.train(epoch=2, train_dataset=dataset, callbacks=[loss_cb])
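The shard call in the example splits each dimension of a MatMul input by the matching entry of its strategy tuple. The helper below, slice_shape, is a hypothetical pure-Python illustration (not a MindSpore API) of the resulting per-device slice shapes. The global batch size of 8 used for the input is an assumption, since the actual value depends on the number of devices.

```python
from typing import Tuple


def slice_shape(shape: Tuple[int, ...], strategy: Tuple[int, ...]) -> Tuple[int, ...]:
    """Per-device slice shape when dimension i is split into strategy[i] equal parts."""
    if len(shape) != len(strategy):
        raise ValueError("strategy needs one entry per dimension")
    for dim, split in zip(shape, strategy):
        if dim % split != 0:
            raise ValueError(f"dimension {dim} is not divisible by {split}")
    return tuple(dim // split for dim, split in zip(shape, strategy))


# Weight of dense1: [28*28, 64], sharded with (1, 2) in the example.
weight_slice = slice_shape((28 * 28, 64), (1, 2))
print(weight_slice)  # (784, 32)

# Input of dense1, assuming a global batch of 8 (hypothetical: 8 devices x batch 1).
input_slice = slice_shape((8, 28 * 28), (2, 1))
print(input_slice)  # (4, 784)

# With tensor_slice_align_size=8, the weight's last-dimension slice is aligned.
print(weight_slice[-1] % 8 == 0)  # True
```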