mindspore.train.TrainFaultTolerance
- class mindspore.train.TrainFaultTolerance(ckpt_save_path=None, **kwargs)
This callback enables the MindIO TFT (training fault tolerance) feature and performs TFT operations during training, such as TFT initialization, status reporting, and exception handling.
Note
Supported only in graph mode on Ascend. The sink size must be less than or equal to 1.
- Parameters
ckpt_save_path (str) – Directory where checkpoints are saved when a failure occurs. Each save creates a new subdirectory named 'ttp_saved_checkpoints-step_{cur_step_num}' in this directory. Default: None.
kwargs (dict) – Other dictionary-type parameters. When ckpt_save_path is None, kwargs must provide a parameter named ckpt_save_fn, which points to a function used to save the checkpoint. The prototype of ckpt_save_fn is def save_ckpt(cb_params, append_dict). When both ckpt_save_path and ckpt_save_fn are provided, ckpt_save_fn takes priority.
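The parameter rules above can be modelled in plain Python. This is a minimal sketch, not MindSpore's implementation: resolve_checkpoint_target and failure_ckpt_dir are hypothetical helpers, and raising ValueError for the missing-argument case is an assumption (the documentation only says ckpt_save_fn must be provided). save_ckpt is a placeholder with the documented prototype.

```python
import os
from typing import Callable, Optional


def resolve_checkpoint_target(ckpt_save_path: Optional[str] = None, **kwargs):
    """Hypothetical helper modelling the documented ckpt_save_path/ckpt_save_fn rules.

    Returns ("fn", ckpt_save_fn) when a custom save function is supplied,
    otherwise ("path", ckpt_save_path).
    """
    ckpt_save_fn: Optional[Callable] = kwargs.get("ckpt_save_fn")
    if ckpt_save_fn is not None:
        # ckpt_save_fn takes priority even when ckpt_save_path is also given.
        return "fn", ckpt_save_fn
    if ckpt_save_path is None:
        # Assumption: the sketch reports the missing argument as a ValueError.
        raise ValueError("ckpt_save_fn must be provided when ckpt_save_path is None")
    return "path", ckpt_save_path


def failure_ckpt_dir(ckpt_save_path: str, cur_step_num: int) -> str:
    # Subdirectory name documented for checkpoints saved when a failure occurs.
    return os.path.join(ckpt_save_path, f"ttp_saved_checkpoints-step_{cur_step_num}")


def save_ckpt(cb_params, append_dict):
    # Placeholder with the documented prototype; a real function would write
    # a checkpoint for the network held in cb_params.
    pass
```

With both arguments given, resolve_checkpoint_target("/tmp/ckpt", ckpt_save_fn=save_ckpt) selects the function, matching the priority rule.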
- Raises
Exception – TFT initialization failed.
ModuleNotFoundError – The MindIO TFT whl package is not installed.
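To avoid the ModuleNotFoundError above, a training script can check for the package before adding the callback. This is a sketch using importlib.util.find_spec; the helper is_module_installed is hypothetical, and the import name "mindio_ttp" is an assumption that should be checked against the MindIO TFT package's own documentation.

```python
import importlib.util


def is_module_installed(module_name: str) -> bool:
    """Return True if module_name can be found in this environment."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ModuleNotFoundError, ValueError):
        # find_spec raises ModuleNotFoundError when a parent package of a
        # dotted name is missing, and ValueError if a module's __spec__ is None.
        return False


# Assumption: "mindio_ttp" is used here as the import name of the MindIO TFT package.
MINDIO_TFT_MODULE = "mindio_ttp"

if is_module_installed(MINDIO_TFT_MODULE):
    print("MindIO TFT found; TrainFaultTolerance can be added to the callbacks.")
else:
    print("MindIO TFT not found; skipping the TrainFaultTolerance callback.")
```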
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
It is recommended to launch with msrun. See the msrun startup documentation for more details.
This example should be run with 4 devices.
>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, TrainFaultTolerance
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
...                              "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
...     def __init__(self, param=None, shape=None):
...         super().__init__()
...         if shape is None:
...             shape = [28 * 28, 512]
...         weight_init = HeUniform(math.sqrt(5))
...         self.param = Parameter(initializer(weight_init, shape), name="param")
...         if param is not None:
...             self.param = param
...         self.print = ops.Print()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         out = self.matmul(x, self.param)
...         self.print("out is:", out)
...         return out
>>>
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.layer1 = MatMulCell()
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(512, 512)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(512, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     dataset = ds.MnistDataset(dataset_path)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>>
>>> dataset = create_dataset(32)
>>>
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
>>> tft_cb = TrainFaultTolerance()
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
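A possible msrun launch for the 4-device example is sketched below. The script name train.py, the dataset path, the port, and the log directory are placeholders; the example reads the MNIST location from the DATA_PATH environment variable. Check the msrun startup documentation for the options supported by your MindSpore version.

```shell
# Placeholder path: directory containing the MNIST training data.
export DATA_PATH=/path/to/MNIST_Data/train
# Start 4 worker processes on this node; logs go to ./msrun_log.
msrun --worker_num=4 --local_worker_num=4 --master_port=8118 --log_dir=msrun_log --join=True train.py
```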
- end(run_context)
Unregisters MindIO TFT when training ends.
- Parameters
run_context (RunContext) – Context of the running training. Refer to mindspore.train.RunContext for details.
- classmethod get_optimizer_wrapper(origin_opt_cls)
Optimizer wrapper function used when TFT is enabled.
- Parameters
origin_opt_cls (Class) – Original optimizer class.
- on_train_begin(run_context)
Registers training parameters with MindIO TFT at the beginning of training.
- Parameters
run_context (RunContext) – Context of the running training. Refer to mindspore.train.RunContext for details.
- on_train_step_end(run_context)
Reports status to MindIO TFT after each training step finishes.
- Parameters
run_context (RunContext) – Context of the running training. Refer to mindspore.train.RunContext for details.
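The order in which the hooks above act can be illustrated with a pure-Python stand-in. This is a sketch only: RecordingTFTCallback and simulate_training are hypothetical, a plain dict replaces mindspore.train.RunContext, and nothing here talks to MindIO TFT. It just records register, report, and unregister in the order the documentation describes.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RecordingTFTCallback:
    """Hypothetical stand-in that records the documented hook responsibilities."""
    events: List[str] = field(default_factory=list)

    def on_train_begin(self, run_context):
        # Documented role: register training parameters with MindIO TFT.
        self.events.append("register")

    def on_train_step_end(self, run_context):
        # Documented role: report status to MindIO TFT after every step.
        self.events.append(f"report step {run_context['cur_step_num']}")

    def end(self, run_context):
        # Documented role: unregister MindIO TFT when training ends.
        self.events.append("unregister")


def simulate_training(callback, num_steps):
    # Simplified training loop; a dict stands in for mindspore.train.RunContext.
    run_context = {"cur_step_num": 0}
    callback.on_train_begin(run_context)
    for step in range(1, num_steps + 1):
        run_context["cur_step_num"] = step
        callback.on_train_step_end(run_context)
    callback.end(run_context)
    return callback.events
```

For example, simulate_training(RecordingTFTCallback(), 2) records registration once, one report per step, and a single unregistration at the end.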