mindspore.parallel.strategy.get_strategy_metadata
- mindspore.parallel.strategy.get_strategy_metadata(network, rank_id=None)
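Get the parameter slicing strategy information of this cell, either for all ranks or for a specific rank. For more information on layouts, please refer to mindspore.parallel.Layout.
- Parameters
  - network – The network to query, for example model.train_network in the example below.
  - rank_id (int, optional) – The rank whose strategy information is returned. Default: None, in which case the strategy information for all ranks is returned.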
- Returns
dict. The parameter slicing strategies of the network. Each key is a rank ID, and each value is a dictionary that maps every parameter name on that rank to its slicing strategy. If rank_id is specified, only the strategy information for that rank is returned; otherwise, the strategy information for all ranks in the network is returned. Returns None if strategy metadata is not supported.
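The nested structure described above can be illustrated with plain Python. This sketch does not call MindSpore: it assumes, based only on the description above, that the result is either None or a dict keyed by integer rank IDs whose values map parameter names to slicing strategies, and it treats each strategy as an opaque value. The helper names params_on_rank and strategy_per_rank and the sample metadata are hypothetical.

```python
from typing import Any, Dict, Optional

# Assumed shape, per the Returns description: {rank_id: {param_name: slicing_strategy}}.
# Strategies are treated as opaque values (Any); their real type is not assumed here.
StrategyMetadata = Dict[int, Dict[str, Any]]


def params_on_rank(metadata: Optional[StrategyMetadata], rank_id: int) -> Dict[str, Any]:
    """Hypothetical helper: return {param_name: strategy} for one rank, or {} if absent."""
    if metadata is None:  # None means strategy metadata is not supported
        return {}
    return metadata.get(rank_id, {})


def strategy_per_rank(metadata: Optional[StrategyMetadata], param_name: str) -> Dict[int, Any]:
    """Hypothetical helper: collect one parameter's strategy on every rank that has it."""
    if metadata is None:
        return {}
    return {
        rank: params[param_name]
        for rank, params in sorted(metadata.items())  # iterate ranks in ascending order
        if param_name in params
    }


# Hypothetical metadata for two ranks; the string values stand in for real strategies.
fake_metadata: StrategyMetadata = {
    0: {"conv1.weight": "strategy_a", "fc1.weight": "strategy_b"},
    1: {"conv1.weight": "strategy_a", "fc1.weight": "strategy_c"},
}

print(params_on_rank(fake_metadata, 1))
print(strategy_per_rank(fake_metadata, "fc1.weight"))
```

If the assumed shape holds, global_info from the example below could be passed in place of fake_metadata to inspect how a single parameter is sliced on each rank.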
Examples
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.communication import init
>>> from mindspore.nn.utils import no_init_parameters
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> from mindspore.train import Model
>>> from mindspore.parallel.strategy import (get_strategy_metadata, get_current_strategy_metadata,
...                                          enable_save_strategy_online, clear_strategy_metadata)
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_seed(1)
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> with no_init_parameters():
...     net = LeNet5()
...     optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> train_net = AutoParallel(net, parallel_mode="semi_auto")
>>> model = Model(network=train_net, loss_fn=loss, optimizer=optim, metrics=None)
>>>
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>>
>>> enable_save_strategy_online()
>>> model.train(2, dataset)
>>>
>>> # Strategy information for all ranks
>>> global_info = get_strategy_metadata(network=model.train_network)
>>> # Strategy information for rank 0 only
>>> rank0_info = get_strategy_metadata(network=model.train_network, rank_id=0)
>>> # Strategy information for the current rank
>>> local_info = get_current_strategy_metadata(network=model.train_network)
>>> clear_strategy_metadata()