mindspore.set_algo_parameters

查看源文件
mindspore.set_algo_parameters(**kwargs)[源代码]

设置并行策略搜索算法中的参数。有关典型用法,请参见 test_auto_parallel_resnet.py

说明

属性名称为必填项。此接口仅在AUTO_PARALLEL模式下工作。

参数:
  • fully_use_devices (bool) - 表示是否仅搜索充分利用所有可用设备的策略。默认值: True 。例如,如果有8个可用设备,当该参数设为 True 时,策略(4, 1)将不包括在ReLU的候选策略中,因为策略(4, 1)仅使用4个设备。

  • elementwise_op_strategy_follow (bool) - 表示elementwise算子是否具有与后续算子一样的策略,elementwise算子是指对输入张量逐元素应用一个函数变换的算子,如Add、ReLU等。默认值: False 。例如,Add的输出给了ReLU,如果该参数设置为 True ,则算法搜索的策略可以保证这两个算子的策略是一致的,例如,ReLU的策略(8, 1)和Add的策略((8, 1), (8, 1))。

  • enable_algo_approxi (bool) - 表示是否在算法中启用近似。默认值: False 。由于大型DNN模型的并行搜索策略有较大的解空间,该算法在这种情况下耗时较长。为了缓解这种情况,如果该参数设置为 True ,则会进行近似丢弃一些候选策略,以便缩小解空间。

  • algo_approxi_epsilon (float) - 表示近似算法中使用的epsilon值。默认值: 0.1 此值描述了近似程度。例如,一个算子的候选策略数量为S,如果 enable_algo_approxiTrue ,则剩余策略的大小为min{S, 1/epsilon}。

  • tensor_slice_align_enable (bool) - 表示是否检查MatMul的tensor切片的shape。默认值: False 受某些硬件的属性限制,只有shape较大的MatMul内核才能显示出优势。如果该参数为 True ,则检查MatMul的切片shape以阻断不规则的shape。

  • tensor_slice_align_size (int) - 表示MatMul的最小tensor切片的shape,该值必须在[1,1024]范围内。默认值: 16 。如果 tensor_slice_align_enable 设为 True ,则MatMul tensor的最后维度的切片大小应该是该值的倍数。

异常:
  • ValueError - 无法识别传入的关键字。

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend设备,用户需要准备rank表,设置rank_id和device_id,详见 rank table启动

针对GPU设备,用户需要准备host文件和mpi,详见 mpirun启动

针对CPU设备,用户需要编写动态组网启动脚本,详见 动态组网启动

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
...                              search_mode="sharding_propagation")
>>> init()
>>> ms.set_algo_parameters(fully_use_devices=True)
>>> ms.set_algo_parameters(elementwise_op_strategy_follow=True)
>>> ms.set_algo_parameters(enable_algo_approxi=True)
>>> ms.set_algo_parameters(algo_approxi_epsilon=0.2)
>>> ms.set_algo_parameters(tensor_slice_align_enable=True)
>>> ms.set_algo_parameters(tensor_slice_align_size=8)
>>>
>>> # Define the network structure.
>>> class Dense(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super().__init__()
...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
...         self.matmul = ops.MatMul()
...         self.add = ops.Add()
...
...     def construct(self, x):
...         x = self.matmul(x, self.weight)
...         x = self.add(x, self.bias)
...         return x
>>>
>>> class FFN(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.dense1 = Dense(28*28, 64)
...         self.relu = ops.ReLU()
...         self.dense2 = Dense(64, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.dense1(x)
...         x = self.relu(x)
...         x = self.dense2(x)
...         return x
>>> net = FFN()
>>> net.dense1.matmul.shard(((2, 1), (1, 2)))
>>>
>>> # Create dataset.
>>> step_per_epoch = 16
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> input_data = np.random.rand(1, 28, 28).astype(np.float32)
>>> label_data = np.random.rand(1).astype(np.int32)
>>> fake_dataset = get_dataset(input_data, label_data)
>>> dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
>>> # Train network.
>>> optimizer = nn.Momentum(net.trainable_params(), 1e-3, 0.1)
>>> loss_fn = nn.CrossEntropyLoss()
>>> loss_cb = train.LossMonitor()
>>> model = ms.Model(network=net, loss_fn=loss_fn, optimizer=optimizer)
>>> model.train(epoch=2, train_dataset=dataset, callbacks=[loss_cb])