mindspore.parallel.strategy.get_strategy_metadata
- mindspore.parallel.strategy.get_strategy_metadata(network, rank_id=None)[源代码]
获取当前网络的所有卡的在线策略信息。 关于 Layout 的解释,请参考
mindspore.parallel.Layout
。- 参数:
network (str) - 训练网络的名称。
rank_id (int, 可选) - 指定卡号。默认为
None
,表示返回所有卡的策略。
- 返回:
Dict,返回一个字典,包含所有卡或特定卡的参数切分策略信息。 字典的键是 rank_id,值是该卡所有参数的切分策略。 在每个卡的策略中,key 是参数名称,value 是该参数的切分策略。 如果指定了 rank_id,字典将返回该卡的策略信息;否则,返回网络中所有卡的策略信息。 不支持场景,则返回
None
。
样例:
>>> import mindspore as ms >>> from mindspore import nn >>> from mindspore.communication import init >>> from mindspore.nn.utils import no_init_parameters >>> from mindspore.parallel.auto_parallel import AutoParallel >>> from mindspore.train import Model >>> from mindspore.parallel.strategy import get_strategy_metadata, get_current_strategy_metadata, ... enable_save_strategy_online, clear_strategy_metadata >>> >>> ms.set_context(mode=ms.GRAPH_MODE) >>> init() >>> ms.set_seed(1) >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> with no_init_parameters(): ... net = LeNet5() ... optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) >>> train_net = AutoParallel(net, parallel_mode="semi_auto") >>> model = Model(network=train_net, loss_fn=loss, optimizer=optim, metrics=None) >>> >>> # Create the dataset taking MNIST as an example. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py >>> dataset = create_dataset() >>> >>> enable_save_strategy_online() >>> model.train(2, dataset) >>> >>> global_info = get_strategy_metadata(network=model.train_network) >>> rank0_info = get_strategy_metadata(network=model.train_network, rank_id=0) >>> local_info = get_current_strategy_metadata(network=model.train_network) >>> clear_strategy_metadata()