mindspore.parallel.strategy.get_strategy_metadata

查看源文件
mindspore.parallel.strategy.get_strategy_metadata(network, rank_id=None)[源代码]

获取当前网络的所有卡的在线策略信息。 关于 Layout 的解释,请参考 mindspore.parallel.Layout

参数:
  • network (str) - 训练网络的名称。

  • rank_id (int, 可选) - 指定卡号。默认为 None,表示返回所有卡的策略。

返回:

Dict,返回一个字典,包含所有卡或特定卡的参数切分策略信息。 字典的键是 rank_id,值是该卡所有参数的切分策略。 在每个卡的策略中,key 是参数名称,value 是该参数的切分策略。 如果指定了 rank_id,字典将返回该卡的策略信息;否则,返回网络中所有卡的策略信息。 不支持场景,则返回 None

样例:

>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.communication import init
>>> from mindspore.nn.utils import no_init_parameters
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> from mindspore.train import Model
>>> from mindspore.parallel.strategy import get_strategy_metadata, get_current_strategy_metadata,
...     enable_save_strategy_online, clear_strategy_metadata
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_seed(1)
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> with no_init_parameters():
...     net = LeNet5()
...     optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> train_net = AutoParallel(net, parallel_mode="semi_auto")
>>> model = Model(network=train_net, loss_fn=loss, optimizer=optim, metrics=None)
>>>
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>>
>>> enable_save_strategy_online()
>>> model.train(2, dataset)
>>>
>>> global_info = get_strategy_metadata(network=model.train_network)
>>> rank0_info = get_strategy_metadata(network=model.train_network, rank_id=0)
>>> local_info = get_current_strategy_metadata(network=model.train_network)
>>> clear_strategy_metadata()