Federated-FLModel

class mindspore_federated.FLModel(yaml_data, network, loss_fn=None, optimizers=None, metrics=None, eval_network=None)

High-level API for training and inference in vertical federated learning. FLModel groups networks, optimizers, and other data structures into a single high-level object. It then builds the vertical federated learning process according to the yaml file provided by the developer and provides interfaces for controlling training and inference.

Parameters
  • yaml_data (FLYamlData) – Data class containing information on the vertical federated learning process, including optimizers, gradient calculators, etc. This information is parsed from the yaml file provided by the developer.

  • network (Cell) – Training network. If loss_fn is not specified, network is used directly as the training network and must output the loss. If loss_fn is specified, the training network is constructed from network and loss_fn.

  • loss_fn (Cell) – Loss function. If not specified, network is used as the training network. Default: None.

  • optimizers (Cell) – Custom optimizer used to train the training network. If optimizers is None, FLModel uses the standard MindSpore optimizers specified in the yaml file. Default: None.

  • metrics (Metric) – Metrics used to evaluate the evaluation network. Default: None.

  • eval_network (Cell) – Evaluation network of the party, which outputs the predicted values. Default: None.

Examples

>>> import os
>>> from mindspore_federated import FLModel, FLYamlData
>>> import mindspore.nn as nn
>>> yaml_data = FLYamlData(os.path.join(os.getcwd(), 'net.yaml'))
>>> # define the training network (TrainNet is a user-defined Cell)
>>> train_net = TrainNet()
>>> # define the evaluation network (EvalNet is a user-defined Cell)
>>> eval_net = EvalNet()
>>> eval_metric = nn.Accuracy()
>>> party_fl_model = FLModel(yaml_data, train_net, metrics=eval_metric, eval_network=eval_net)
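
The rule for combining network and loss_fn can be illustrated with a minimal sketch in plain Python. It is not FLModel's implementation: build_train_network and WithLoss are hypothetical names, plain callables stand in for MindSpore Cells, and WithLoss only mimics the idea behind MindSpore's nn.WithLossCell.

```python
from typing import Callable, Optional


class WithLoss:
    """Hypothetical wrapper: feeds the backbone output and a label into loss_fn."""

    def __init__(self, network: Callable, loss_fn: Callable):
        self.network = network
        self.loss_fn = loss_fn

    def __call__(self, data, label):
        # The backbone produces predictions; the loss function turns them into a loss.
        return self.loss_fn(self.network(data), label)


def build_train_network(network: Callable, loss_fn: Optional[Callable] = None) -> Callable:
    """Use `network` as-is unless `loss_fn` is given (the documented rule)."""
    if loss_fn is None:
        # `network` is expected to output the loss by itself.
        return network
    return WithLoss(network, loss_fn)


# Plain Python callables standing in for MindSpore Cells.
def linear(x):
    return 2.0 * x


def squared_error(pred, label):
    return (pred - label) ** 2


train_net = build_train_network(linear, squared_error)
print(train_net(3.0, 5.0))  # (2 * 3 - 5) ** 2 = 1.0
```

When loss_fn is None, the returned object is network itself, so network must already produce the loss.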

backward_one_step(local_data_batch: dict = None, remote_data_batch: dict = None, sens: dict = None)

Run the backward pass of the training network on a data batch.

Parameters
  • local_data_batch (dict) – Data batch read from the local server. Keys are the names of the data items; values are the corresponding tensors.

  • remote_data_batch (dict) – Data batch read from the remote servers of other parties. Keys are the names of the data items; values are the corresponding tensors.

  • sens (dict) – Sensitivity parameters (gradient scale values) used to calculate the gradients of the training network. Keys are the label names specified in the yaml file. Each value is a dict whose keys are the names of the training network's outputs and whose values are the sensitivity tensors of the corresponding outputs.

Returns

Dict, sensitivity parameters or gradient scale values to be sent to other parties. Keys are the label names specified in the yaml file. Each value is a dict whose keys are the names of the training network's inputs and whose values are the sensitivity tensors of the corresponding inputs.

Examples

>>> # item: local data batch; backbone_out: outputs received from the other party
>>> head_scale = party_fl_model.backward_one_step(item, backbone_out)
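
The nested layout of sens and of the returned dict can be sketched in plain Python. The label name 'loss', the tensor name 'embedding', and lists of floats standing in for tensors are all assumptions for illustration; check_sens is a hypothetical helper, not part of mindspore_federated.

```python
from typing import Dict, List

# Stand-in for a tensor: a flat list of floats (assumption for illustration only).
Tensor = List[float]
SensDict = Dict[str, Dict[str, Tensor]]


def check_sens(sens: SensDict) -> None:
    """Check the two-level layout {label_name: {tensor_name: tensor}}."""
    for label_name, per_tensor in sens.items():
        if not isinstance(label_name, str) or not isinstance(per_tensor, dict):
            raise TypeError(f"bad entry for label {label_name!r}")
        for tensor_name, value in per_tensor.items():
            if not isinstance(tensor_name, str) or not isinstance(value, list):
                raise TypeError(f"bad sens tensor {tensor_name!r} under {label_name!r}")


# Hypothetical names: 'loss' is the yaml label name, 'embedding' is a tensor that
# the other party's network outputs and this party's network takes as input.
head_scale: SensDict = {"loss": {"embedding": [0.1, -0.2, 0.05]}}
check_sens(head_scale)

# As we read the Returns description, the returned dict is keyed by this party's
# network inputs, which are the other party's network outputs, so it has the
# shape that the other party's `sens` argument expects.
for label_name, per_input in head_scale.items():
    for name, grad in per_input.items():
        print(label_name, name, len(grad))  # loss embedding 3
```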

eval_one_step(local_data_batch: dict = None, remote_data_batch: dict = None)

Execute the evaluation network using a data batch.

Parameters
  • local_data_batch (dict) – Data batch read from the local server. Keys are the names of the data items; values are the corresponding tensors.

  • remote_data_batch (dict) – Data batch read from the remote servers of other parties. Keys are the names of the data items; values are the corresponding tensors.

Returns

Dict, outputs of the evaluation network. Keys are the output names; values are the corresponding tensors.

Examples

>>> party_fl_model.eval_one_step(eval_item, embedding)
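
Because eval_one_step returns a dict keyed by output name, computing a metric over several batches amounts to looking up the relevant output in each returned dict. The sketch below assumes a binary task, a hypothetical output named 'pred' holding probabilities, and lists in place of tensors; in practice the Metric passed to FLModel (for example nn.Accuracy) would do this work.

```python
from typing import Dict, List


def accuracy_from_outputs(outputs_per_batch: List[Dict[str, List[float]]],
                          labels_per_batch: List[List[int]],
                          output_name: str = "pred",
                          threshold: float = 0.5) -> float:
    """Binary accuracy over batches of eval outputs keyed by output name."""
    correct = total = 0
    for outputs, labels in zip(outputs_per_batch, labels_per_batch):
        # Look up the prediction tensor by its output name.
        for score, label in zip(outputs[output_name], labels):
            correct += int((score >= threshold) == bool(label))
            total += 1
    return correct / total if total else 0.0


# Two hypothetical eval batches shaped like eval_one_step return values.
eval_outputs = [{"pred": [0.9, 0.2]}, {"pred": [0.4, 0.7]}]
eval_labels = [[1, 0], [1, 1]]
print(accuracy_from_outputs(eval_outputs, eval_labels))  # 0.75
```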

forward_one_step(local_data_batch: dict = None, remote_data_batch: dict = None)

Run the forward pass of the training network on a data batch.

Parameters
  • local_data_batch (dict) – Data batch read from the local server. Keys are the names of the data items; values are the corresponding tensors.

  • remote_data_batch (dict) – Data batch read from the remote servers of other parties. Keys are the names of the data items; values are the corresponding tensors.

Returns

Dict, outputs of the training network. Keys are the output names; values are the corresponding tensors.

Examples

>>> logit_out = party_fl_model.forward_one_step(item, backbone_out)
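
How forward_one_step and backward_one_step of two parties could interleave in one training step is sketched below with toy scalar models. FollowerParty and LeaderParty are hypothetical stand-ins that only mimic the documented method signatures; they are not FLModel, and the single shared label name 'loss', the tensor name 'embedding', and plain SGD updates are assumptions. The leader computes the loss, returns the gradient with respect to its remote input, and the follower uses that dict as sens.

```python
from typing import Dict, Optional

Batch = Dict[str, float]


class FollowerParty:
    """Hypothetical bottom party: embedding = w * x (scalars for illustration)."""

    def __init__(self, w: float, lr: float = 0.1):
        self.w, self.lr = w, lr

    def forward_one_step(self, local_data_batch: Batch,
                         remote_data_batch: Optional[Batch] = None) -> Batch:
        return {"embedding": self.w * local_data_batch["x"]}

    def backward_one_step(self, local_data_batch: Batch,
                          remote_data_batch: Optional[Batch] = None,
                          sens: Optional[Dict[str, Batch]] = None) -> None:
        if sens is None:
            raise ValueError("the follower needs sens from the leader")
        # Chain rule: dL/dw = dL/d(embedding) * x, where dL/d(embedding) comes from the leader.
        grad_w = sens["loss"]["embedding"] * local_data_batch["x"]
        self.w -= self.lr * grad_w


class LeaderParty:
    """Hypothetical top party: logit = v * embedding, loss = 0.5 * (logit - y) ** 2."""

    def __init__(self, v: float, lr: float = 0.1):
        self.v, self.lr = v, lr

    def forward_one_step(self, local_data_batch: Batch, remote_data_batch: Batch) -> Batch:
        logit = self.v * remote_data_batch["embedding"]
        return {"logit": logit, "loss": 0.5 * (logit - local_data_batch["y"]) ** 2}

    def backward_one_step(self, local_data_batch: Batch,
                          remote_data_batch: Batch) -> Dict[str, Batch]:
        e = remote_data_batch["embedding"]
        residual = self.v * e - local_data_batch["y"]  # dL/d(logit)
        grad_e = residual * self.v  # gradient w.r.t. the remote input, before v changes
        self.v -= self.lr * residual * e  # SGD step on the leader's own weight
        # Keyed by label name, then by the name of this party's input.
        return {"loss": {"embedding": grad_e}}


follower, leader = FollowerParty(w=1.0), LeaderParty(v=1.0)
x_batch, y_batch = {"x": 2.0}, {"y": 1.0}
losses = []
for step in range(3):
    embedding = follower.forward_one_step(x_batch)             # follower -> leader
    out = leader.forward_one_step(y_batch, embedding)
    head_scale = leader.backward_one_step(y_batch, embedding)  # leader -> follower
    follower.backward_one_step(x_batch, sens=head_scale)
    losses.append(out["loss"])
print(losses)  # the loss decreases step by step
```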

get_compress_configs()

Load the communication compression configurations set in yaml_data and return them for use by the communicator.

Note

Tensors that share the same name cannot use different compression methods.

Returns

Dict. Keys are the tensor names; values are the corresponding compression configurations.

Examples

>>> compress_configs = party_fl_model.get_compress_configs()
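
The restriction in the Note can be illustrated by a sketch that maps each tensor name to exactly one compression setting and rejects conflicting settings for the same name. build_compress_configs and the entry fields ('tensors', 'type', 'bit_num', 'min_max') are hypothetical; they are not the real yaml schema or the library's implementation.

```python
from typing import Dict, List


def build_compress_configs(entries: List[dict]) -> Dict[str, dict]:
    """Map each tensor name to a single compression setting; reject conflicts."""
    configs: Dict[str, dict] = {}
    for entry in entries:
        # Everything except the tensor list describes the compression method.
        setting = {k: v for k, v in entry.items() if k != "tensors"}
        for name in entry["tensors"]:
            if name in configs and configs[name] != setting:
                raise ValueError(f"tensor {name!r} has conflicting compression methods")
            configs[name] = setting
    return configs


# Hypothetical yaml-derived entries; field names are assumptions.
entries = [
    {"tensors": ["embedding"], "type": "min_max", "bit_num": 8},
    {"tensors": ["head_scale"], "type": "min_max", "bit_num": 4},
]
print(build_compress_configs(entries)["embedding"])  # {'type': 'min_max', 'bit_num': 8}
```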

load_ckpt(phrase: str = 'eval', path: str = None)

Load a checkpoint into the training network or the evaluation network.

Parameters
  • phrase (str) – Load the checkpoint into either the evaluation network (if set to ‘eval’) or the training network (if set to ‘train’). Default: ‘eval’.

  • path (str) – Path of the checkpoint to load. If not specified, the ckpt_path specified in the yaml file is used. Default: None.

Examples

>>> party_fl_model.load_ckpt(phrase="eval", path="party_fl_model.ckpt")

save_ckpt(path: str = None)

Save checkpoints of the training network.

Parameters

path (str) – Path to save the checkpoint to. If not specified, the ckpt_path specified in the yaml file is used. Default: None.

Examples

>>> party_fl_model.save_ckpt("party_fl_model.ckpt")
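
Both save_ckpt and load_ckpt fall back to the ckpt_path from the yaml file when path is omitted. The sketch below shows that fallback with a save/load round trip. It uses pickle on a plain dict of parameters as a stand-in for MindSpore checkpoints (which are written by mindspore.save_checkpoint and read by mindspore.load_checkpoint); resolve_ckpt_path, save_params, and load_params are hypothetical helpers, and treating ckpt_path as a file path is an assumption.

```python
import os
import pickle
import tempfile
from typing import Dict, Optional

Params = Dict[str, list]  # stand-in for a network's parameter tensors


def resolve_ckpt_path(path: Optional[str], yaml_ckpt_path: str) -> str:
    # An explicit path wins; otherwise use the ckpt_path taken from the yaml file.
    return path if path is not None else yaml_ckpt_path


def save_params(params: Params, path: Optional[str], yaml_ckpt_path: str) -> str:
    target = resolve_ckpt_path(path, yaml_ckpt_path)
    with open(target, "wb") as f:
        pickle.dump(params, f)
    return target


def load_params(path: Optional[str], yaml_ckpt_path: str) -> Params:
    with open(resolve_ckpt_path(path, yaml_ckpt_path), "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    default_path = os.path.join(tmp, "party_fl_model.ckpt")  # plays the yaml ckpt_path
    saved_to = save_params({"dense.weight": [0.5, -0.5]}, None, default_path)
    restored = load_params(None, default_path)
    print(saved_to == default_path, restored)  # True {'dense.weight': [0.5, -0.5]}
```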

class mindspore_federated.FLYamlData(yaml_path: str)

Data class storing configuration information on the vertical federated learning process, including the inputs, outputs, and hyperparameters of networks, optimizers, operators, etc. This information is parsed from the yaml file provided by the developer of the vertical federated learning system, and the yaml file is verified during parsing. An instance of this class is passed as the first argument (yaml_data) of FLModel.

Parameters

yaml_path (str) – Path of the yaml file.

Examples

>>> import os
>>> from mindspore_federated import FLYamlData
>>> yaml_data = FLYamlData(os.path.join(os.getcwd(), 'net.yaml'))
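
The kind of verification performed while parsing can be sketched as a check over an already-parsed config dict (the standard library has no yaml parser, so a literal dict stands in for the parsed file). validate_config, REQUIRED_SECTIONS, and the section names 'role', 'model', and 'opts' are hypothetical; the real schema and checks are defined by MindSpore Federated.

```python
from typing import Any, Dict, List

# Hypothetical required top-level sections, not the real FLYamlData schema.
REQUIRED_SECTIONS = ("role", "model", "opts")


def validate_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in an already-parsed yaml config."""
    problems = [f"missing section: {key}" for key in REQUIRED_SECTIONS if key not in cfg]
    for net_name, net in cfg.get("model", {}).items():
        # Every network must declare which tensors it consumes and produces.
        for field in ("inputs", "outputs"):
            if not net.get(field):
                problems.append(f"{net_name}: '{field}' must be a non-empty list")
    return problems


cfg = {
    "role": "leader",
    "model": {
        "train_net": {"inputs": [{"name": "embedding"}], "outputs": [{"name": "loss"}]},
        "eval_net": {"inputs": [{"name": "embedding"}], "outputs": []},
    },
}
print(validate_config(cfg))
# ['missing section: opts', "eval_net: 'outputs' must be a non-empty list"]
```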