mindarmour

MindArmour is a toolbox for MindSpore that enhances model trustworthiness and enables privacy-preserving machine learning.

class mindarmour.Attack[source]

The abstract base class for all attack classes creating adversarial examples.

batch_generate(inputs, labels, batch_size=64)[source]

Generate adversarial examples in batch, based on input samples and their labels.

Parameters
  • inputs (Union[numpy.ndarray, tuple]) – Samples based on which adversarial examples are generated.

  • labels (Union[numpy.ndarray, tuple]) – Original/target labels. If an input has more than one label, its labels are wrapped in a tuple.

  • batch_size (int) – The number of samples in one batch.

Returns

numpy.ndarray, generated adversarial examples

Examples

>>> inputs = np.array([[0.2, 0.4, 0.5, 0.2], [0.7, 0.2, 0.4, 0.3]])
>>> labels = np.array([3, 0])
>>> advs = attack.batch_generate(inputs, labels, batch_size=2)

abstract generate(inputs, labels)[source]

Generate adversarial examples based on normal samples and their labels.

Parameters
  • inputs (Union[numpy.ndarray, tuple]) – Samples based on which adversarial examples are generated.

  • labels (Union[numpy.ndarray, tuple]) – Original/target labels. If an input has more than one label, its labels are wrapped in a tuple.

Raises

NotImplementedError – It is an abstract method.
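
The relationship between generate and batch_generate can be illustrated with a small stand-alone sketch. It assumes batch_generate simply slices the samples into chunks of batch_size and calls generate on each chunk. ScaledNoiseAttack and its eps parameter are hypothetical names used only for illustration and are not part of MindArmour.

```python
import numpy as np


class ScaledNoiseAttack:
    """Hypothetical stand-in for an Attack subclass (not part of MindArmour)."""

    def __init__(self, eps=0.1):
        self.eps = eps

    def generate(self, inputs, labels):
        # Toy perturbation: shift every feature by eps and clip to [0, 1].
        return np.clip(inputs + self.eps, 0.0, 1.0)

    def batch_generate(self, inputs, labels, batch_size=64):
        # Assumed behaviour: walk over the samples in slices of batch_size,
        # call generate() on each slice, then stitch the results together.
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        outputs = []
        for start in range(0, len(inputs), batch_size):
            stop = start + batch_size
            outputs.append(self.generate(inputs[start:stop], labels[start:stop]))
        return np.concatenate(outputs, axis=0)


inputs = np.array([[0.2, 0.4, 0.5, 0.2],
                   [0.7, 0.2, 0.4, 0.3],
                   [0.95, 0.1, 0.0, 1.0]])
labels = np.array([3, 0, 1])
advs = ScaledNoiseAttack(eps=0.1).batch_generate(inputs, labels, batch_size=2)
print(advs.shape)
```

A concrete subclass would only need to replace the body of generate with a real attack; the batching loop stays the same.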

class mindarmour.BlackModel[source]

The abstract class which treats the target model as a black box. The model should be defined by users.

is_adversarial(data, label, is_targeted)[source]

Check whether the input sample is an adversarial example.

Parameters
  • data (numpy.ndarray) – The input sample to be checked, typically a maliciously perturbed example.

  • label (numpy.ndarray) – For targeted attacks, label is the intended label of the perturbed example. For untargeted attacks, label is the original label of the corresponding unperturbed sample.

  • is_targeted (bool) – For targeted/untargeted attacks, select True/False.

Returns

bool.
  • If True, the input sample is adversarial.

  • If False, the input sample is not adversarial.

abstract predict(inputs)[source]

Predict using the user-specified model. The shape of the prediction results should be (m, n), where m is the number of input samples and n is the number of classes the model classifies.

Parameters

inputs (numpy.ndarray) – The input samples to be predicted.

Raises

NotImplementedError – It is an abstract method.
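
A minimal sketch of how predict and is_adversarial could fit together, assuming the decision is made by taking the argmax of the (m, n) scores returned by predict and that label is a class index. ToyBlackModel and its linear scorer are hypothetical placeholders, not MindArmour code.

```python
import numpy as np


class ToyBlackModel:
    """Hypothetical black-box wrapper; the linear scorer is only a placeholder."""

    def __init__(self, weights):
        self.weights = np.asarray(weights, dtype=np.float32)  # shape (features, n)

    def predict(self, inputs):
        # Return scores of shape (m, n): m samples, n classes.
        return np.asarray(inputs, dtype=np.float32) @ self.weights

    def is_adversarial(self, data, label, is_targeted):
        # Add a batch dimension, predict, and take the top-scoring class.
        scores = self.predict(np.expand_dims(data, axis=0))[0]
        predicted = int(np.argmax(scores))
        if is_targeted:
            return predicted == int(label)  # reached the intended class
        return predicted != int(label)      # moved away from the original class


model = ToyBlackModel(np.eye(3))
sample = np.array([0.1, 0.7, 0.2])  # the top-scoring class is 1
targeted_hit = model.is_adversarial(sample, 1, is_targeted=True)
untargeted_hit = model.is_adversarial(sample, 1, is_targeted=False)
print(targeted_hit, untargeted_hit)
```

The same sample is judged differently depending on is_targeted, because label means "intended class" in one case and "original class" in the other.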

class mindarmour.DPModel(micro_batches=2, norm_bound=1.0, noise_mech=None, clip_mech=None, **kwargs)[source]

This class overloads mindspore.train.model.Model.

Parameters
  • micro_batches (int) – The number of small batches split from an original batch. Default: 2.

  • norm_bound (float) – Used to clip the bound. If set to 1, the original data is returned. Default: 1.0.

  • noise_mech (Mechanisms) – The object that generates the different types of noise. Default: None.

  • clip_mech (Mechanisms) – The object is used to update the adaptive clip. Default: None.

Raises
  • ValueError – If DPOptimizer and noise_mech are both None or both not None.

  • ValueError – If the mech method of noise_mech or DPOptimizer is adaptive while clip_mech is not None.
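
The roles of micro_batches, norm_bound and the noise mechanism resemble the usual steps of differentially private SGD: clip each micro-batch gradient, aggregate, then add Gaussian noise scaled by the clipping bound. The sketch below shows that generic recipe in NumPy. It is not MindArmour's implementation, and private_gradient, noise_multiplier and seed are hypothetical names.

```python
import numpy as np


def private_gradient(micro_batch_grads, norm_bound=1.0, noise_multiplier=0.01, seed=0):
    """Clip each micro-batch gradient to norm_bound, average, and add Gaussian noise."""
    rng = np.random.default_rng(seed)
    clipped = []
    for grad in micro_batch_grads:
        grad = np.asarray(grad, dtype=np.float64)
        norm = np.linalg.norm(grad)
        # Scale down only when the gradient's L2 norm exceeds the bound.
        scale = min(1.0, norm_bound / (norm + 1e-12))
        clipped.append(grad * scale)
    mean_grad = np.mean(clipped, axis=0)
    # The noise standard deviation is proportional to the clipping bound.
    noise = rng.normal(0.0, norm_bound * noise_multiplier, size=mean_grad.shape)
    # (sum of clipped gradients + noise) / number of micro-batches
    return mean_grad + noise / len(clipped)


grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]  # L2 norms 5.0 and 0.5
noiseless = private_gradient(grads, norm_bound=1.0, noise_multiplier=0.0)
print(noiseless)
```

With noise_multiplier set to 0 the result is just the mean of the clipped gradients, which makes the clipping step easy to check in isolation.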

Examples

>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> micro_batches = 2
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt.set_mechanisms('Gaussian',
...                            norm_bound=norm_bound,
...                            initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
...                                          learning_rate=0.1, momentum=0.9)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
...                                            decay_policy='Linear',
...                                            learning_rate=0.01,
...                                            target_unclipped_quantile=0.9,
...                                            fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
...                 norm_bound=norm_bound,
...                 clip_mech=clip_mech,
...                 noise_mech=None,
...                 network=network,
...                 loss_fn=loss,
...                 optimizer=net_opt,
...                 metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator,
...                             ['data', 'label'])
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)

class mindarmour.Defense(network)[source]

The abstract base class for all defense classes that defend against adversarial examples.

Parameters

network (Cell) – A MindSpore-style deep learning model to be defended.

batch_defense(inputs, labels, batch_size=32, epochs=5)[source]

Defend the model with samples in batches.

Parameters
  • inputs (numpy.ndarray) – Samples based on which adversarial examples are generated.

  • labels (numpy.ndarray) – Labels of input samples.

  • batch_size (int) – Number of samples in one batch.

  • epochs (int) – Number of epochs.

Returns

numpy.ndarray, loss of batch_defense operation.

Raises

ValueError – If batch_size is 0.

abstract defense(inputs, labels)[source]

Defend the model with samples.

Parameters
  • inputs (numpy.ndarray) – Samples based on which adversarial examples are generated.

  • labels (numpy.ndarray) – Labels of input samples.

Raises

NotImplementedError – It is an abstract method.

class mindarmour.Detector[source]

The abstract base class for all adversarial example detectors.

abstract detect(inputs)[source]

Detect adversarial examples from input samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be detected.

Raises

NotImplementedError – It is an abstract method.

abstract detect_diff(inputs)[source]

Calculate the difference between the input samples and de-noised samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be detected.

Raises

NotImplementedError – It is an abstract method.

abstract fit(inputs, labels=None)[source]

Fit a threshold and reject adversarial examples whose difference from their de-noised versions is larger than the threshold. The threshold is determined by a given false positive rate when applied to normal samples.

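Parameters
  • inputs (numpy.ndarray) – The input samples used to calculate the threshold.

  • labels (numpy.ndarray) – Labels of the input samples. Default: None.

Raises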

NotImplementedError – It is an abstract method.
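
One way to read "the threshold is determined by a certain false positive rate" is as a quantile of the differences measured on normal samples. The sketch below assumes that interpretation. fit_threshold, detect and false_positive_rate are hypothetical names, and the difference values are made up.

```python
import numpy as np


def fit_threshold(normal_diffs, false_positive_rate=0.05):
    """Pick the threshold that flags roughly false_positive_rate of normal samples."""
    return float(np.quantile(normal_diffs, 1.0 - false_positive_rate))


def detect(diffs, threshold):
    """Return 1 for samples whose difference exceeds the threshold, else 0."""
    return (np.asarray(diffs) > threshold).astype(int)


# Differences between normal samples and their de-noised versions (made-up values).
normal_diffs = np.linspace(0.0, 1.0, 101)
threshold = fit_threshold(normal_diffs, false_positive_rate=0.05)
flags = detect([0.2, 0.99], threshold)
print(threshold, flags)
```

A concrete detector would compute the differences with its own detect_diff before fitting; only the thresholding step is shown here.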

abstract transform(inputs)[source]

Filter adversarial noises in input samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be transformed.

Raises

NotImplementedError – It is an abstract method.

class mindarmour.Fuzzer(target_model, train_dataset, neuron_num, segmented_num=1000)[source]

Fuzzing test framework for deep neural networks.

Reference: DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep Neural Networks

Parameters
  • target_model (Model) – Target fuzz model.

  • train_dataset (numpy.ndarray) – Training dataset used for determining the neurons’ output boundaries.

  • neuron_num (int) – The number of testing neurons.

  • segmented_num (int) – The number of segmented sections of neurons’ output intervals. Default: 1000.

Examples

>>> net = Net()
>>> mutate_config = [{'method': 'Blur',
...                   'params': {'auto_param': [True]}},
...                  {'method': 'Contrast',
...                   'params': {'factor': [2]}},
...                  {'method': 'Translate',
...                   'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}},
...                  {'method': 'FGSM',
...                   'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> model = Model(net)
>>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
>>> # initial_seeds is a user-prepared list of [image_data, one-hot label] pairs.
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds)

fuzzing(mutate_config, initial_seeds, coverage_metric='KMNC', eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20)[source]

Fuzzing tests for deep neural networks.

Parameters
  • mutate_config (list) – Mutation configurations. The format is [{'method': 'Blur', 'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}}, {'method': 'Contrast', 'params': {'factor': [1, 1.5, 2]}}, {'method': 'FGSM', 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, …]. The supported methods are listed in self._strategies, and the params of each method must be within the range of its optional parameters. Supported methods fall into three types. Pixel-value-based transform methods: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Affine transform methods: 'Translate', 'Scale', 'Shear' and 'Rotate'. Attack methods: 'FGSM', 'PGD' and 'MDIIM'. mutate_config must include at least one pixel-value-based transform method. For the first two types, see 'mindarmour/fuzz_testing/image_transform.py' for how to set the parameters. For attack methods, the optional parameters are given in self._attack_param_checklists.

  • initial_seeds (list[list]) – Initial seeds used to generate mutated samples. The format of initial seeds is [[image_data, label], […], …] and the label must be one-hot.

  • coverage_metric (str) – Model coverage metric of neural networks. All supported metrics are: ‘KMNC’, ‘NBC’, ‘SNAC’. Default: ‘KMNC’.

  • eval_metrics (Union[list, tuple, str]) – Evaluation metrics. If the type is ‘auto’, it will calculate all the metrics, else if the type is list or tuple, it will calculate the metrics specified by user. All supported evaluate methods are ‘accuracy’, ‘attack_success_rate’, ‘kmnc’, ‘nbc’, ‘snac’. Default: ‘auto’.

  • max_iters (int) – Maximum number of times a seed is selected for mutation. Default: 10000.

  • mutate_num_per_seed (int) – The number of mutations per seed. Default: 20.

Returns

  • list, mutated samples in fuzz_testing.

  • list, ground truth labels of mutated samples.

  • list, preds of mutated samples.

  • list, strategies of mutated samples.

  • dict, metrics report of fuzzer.

Raises
  • TypeError – If the type of eval_metrics is not str, list or tuple.

  • TypeError – If the type of metric in list eval_metrics is not str.

  • ValueError – If eval_metrics is not equal to 'auto' when its type is str.

  • ValueError – If metric in list eval_metrics is not in [‘accuracy’, ‘attack_success_rate’, ‘kmnc’, ‘nbc’, ‘snac’].
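
The 'KMNC' coverage metric and the segmented_num parameter can be illustrated with a small NumPy sketch of k-multisection neuron coverage as it is commonly defined: each neuron's output range observed on the training data is split into segmented_num equal sections, and coverage is the fraction of sections hit by the test outputs. The function kmnc below is a hypothetical helper that works on precomputed neuron outputs, not the Fuzzer's own code.

```python
import numpy as np


def kmnc(train_outputs, test_outputs, segmented_num=1000):
    """Fraction of (neuron, section) pairs hit by test_outputs.

    Both arguments are arrays of shape (num_samples, neuron_num).
    """
    train_outputs = np.asarray(train_outputs, dtype=np.float64)
    test_outputs = np.asarray(test_outputs, dtype=np.float64)
    lower = train_outputs.min(axis=0)
    upper = train_outputs.max(axis=0)
    width = (upper - lower) / segmented_num
    safe_width = np.where(width > 0, width, 1.0)  # avoid division by zero
    covered = np.zeros((train_outputs.shape[1], segmented_num), dtype=bool)
    for outputs in test_outputs:
        # Only outputs inside [lower, upper] count towards this metric.
        inside = (outputs >= lower) & (outputs <= upper) & (width > 0)
        section = np.floor((outputs - lower) / safe_width).astype(int)
        section = np.clip(section, 0, segmented_num - 1)
        neurons = np.nonzero(inside)[0]
        covered[neurons, section[neurons]] = True
    return covered.sum() / covered.size


train_outputs = np.array([[0.0, 0.0], [1.0, 2.0]])  # two neurons
test_outputs = np.array([[0.1, 1.9], [0.6, 5.0]])   # 5.0 is out of range
coverage = kmnc(train_outputs, test_outputs, segmented_num=4)
print(coverage)
```

In this toy case three of the eight sections are hit, so a larger segmented_num makes full coverage harder to reach with the same number of mutated samples.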

class mindarmour.ImageInversionAttack(network, input_shape, input_bound, loss_weights=(1, 0.2, 5))[source]

An attack method used to reconstruct images by inverting their deep representations.

References: Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. 2014.

Parameters
  • network (Cell) – The network used to infer images’ deep representations.

  • input_shape (tuple) – Data shape of single network input, which should be in accordance with the given network. The format of shape should be (channel, image_width, image_height).

  • input_bound (Union[tuple, list]) – The pixel range of original images, which should be like [minimum_pixel, maximum_pixel] or (minimum_pixel, maximum_pixel).

  • loss_weights (Union[list, tuple]) – Weights of the three sub-losses in InversionLoss, which can be adjusted to obtain better results. Default: (1, 0.2, 5).

Raises
  • TypeError – If the type of network is not Cell.

  • ValueError – If any value of input_shape is not positive int.

  • ValueError – If any value of loss_weights is not positive value.

evaluate(original_images, inversion_images, labels=None, new_network=None)[source]

Evaluate the quality of inverted images with three indexes: the average L2 distance and the average SSIM value between the original images and the inversion images, and the average confidence that a newly trained network assigns to the true labels of the inverted images.

Parameters
  • original_images (numpy.ndarray) – Original images, whose shape should be (img_num, channels, img_width, img_height).

  • inversion_images (numpy.ndarray) – Inversion images, whose shape should be (img_num, channels, img_width, img_height).

  • labels (numpy.ndarray) – Ground truth labels of original images. Default: None.

  • new_network (Cell) – A network whose structure contains all parts of self._network, but loaded with different checkpoint file. Default: None.

Returns

tuple, average L2 distance, average SSIM value and average confidence (if labels or new_network is None, the average confidence is None).
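
As a rough illustration of the first index, the sketch below computes a mean per-image L2 distance for two batches shaped (img_num, channels, img_width, img_height). It assumes the distance is taken over each flattened image and then averaged. The exact normalisation used by evaluate is not stated here, and average_l2_distance is a hypothetical helper.

```python
import numpy as np


def average_l2_distance(original_images, inversion_images):
    """Mean per-image L2 distance between two equally shaped image batches."""
    original_images = np.asarray(original_images, dtype=np.float64)
    inversion_images = np.asarray(inversion_images, dtype=np.float64)
    if original_images.shape != inversion_images.shape:
        raise ValueError("both image batches must have the same shape")
    # Flatten each image so the norm is taken over all of its pixels.
    diff = (original_images - inversion_images).reshape(original_images.shape[0], -1)
    return float(np.mean(np.linalg.norm(diff, axis=1)))


ori_images = np.zeros((2, 1, 2, 2))
inver_images = np.ones((2, 1, 2, 2))
avg_l2 = average_l2_distance(ori_images, inver_images)
print(avg_l2)
```

Each toy image differs by 1 in all four pixels, so the per-image distance is the square root of 4.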

Examples

>>> net = LeNet5()
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1),
...                                         loss_weights=[1, 0.2, 5])
>>> features = np.random.random((2, 10)).astype(np.float32)
>>> inver_images = inversion_attack.generate(features, iters=10)
>>> ori_images = np.random.random((2, 1, 32, 32))
>>> result = inversion_attack.evaluate(ori_images, inver_images)
>>> print(len(result))
3

generate(target_features, iters=100)[source]

Reconstruct images based on target_features.

Parameters
  • target_features (numpy.ndarray) – Deep representations of original images. The first dimension of target_features should be img_num. It should be noted that the shape of target_features should be (1, dim2, dim3, …) if img_num equals 1.

  • iters (int) – Number of iterations of the inversion attack, which should be a positive integer. Default: 100.

Returns

numpy.ndarray, reconstructed images, which are expected to be similar to original images.

Raises
  • TypeError – If the type of target_features is not numpy.ndarray.

  • ValueError – If iters is not a positive int.

Examples

>>> net = LeNet5()
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1),
...                                         loss_weights=[1, 0.2, 5])
>>> features = np.random.random((2, 10)).astype(np.float32)
>>> images = inversion_attack.generate(features, iters=10)
>>> print(images.shape)
(2, 1, 32, 32)

class mindarmour.MembershipInference(model, n_jobs=-1)[source]

Membership inference, the evaluation method proposed by Shokri, Stronati, Song and Shmatikov, is a grey-box attack. It requires the loss or logits results of training samples.

References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. Membership Inference Attacks against Machine Learning Models. 2017.

Parameters
  • model (Model) – Target model.

  • n_jobs (int) – Number of jobs run in parallel. -1 means using all processors, otherwise the value of n_jobs must be a positive integer.

Examples

>>> # train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>> # test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> attack_model = MembershipInference(model, n_jobs=-1)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> attack_model.train(train_1, test_1, config)
>>> metrics = ["precision", "recall", "accuracy"]
>>> result = attack_model.eval(train_2, test_2, metrics)

Raises
  • TypeError – If type of model is not mindspore.train.Model.

  • TypeError – If type of n_jobs is not int.

  • ValueError – The value of n_jobs is neither -1 nor a positive integer.

eval(dataset_train, dataset_test, metrics)[source]

Evaluate the privacy of the target model. The evaluation indicators are specified by metrics.

Parameters
  • dataset_train (mindspore.dataset) – The training dataset for the target model.

  • dataset_test (mindspore.dataset) – The test dataset for the target model.

  • metrics (Union[list, tuple]) – Evaluation indicators. The value of metrics must be in [“precision”, “accuracy”, “recall”]. Default: [“precision”].

Returns

list, each element contains an evaluation indicator for the attack model.
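
The three indicators can be sketched in plain Python, assuming the attack model outputs a binary guess per sample (1 for "was in the training set") and that samples from dataset_train count as members while samples from dataset_test count as non-members. membership_metrics is a hypothetical helper and the predictions are made up.

```python
def membership_metrics(predictions, is_member, metrics=("precision",)):
    """Score binary membership guesses (1 means 'was in the training set')."""
    pairs = list(zip(predictions, is_member))
    tp = sum(1 for p, m in pairs if p == 1 and m == 1)
    fp = sum(1 for p, m in pairs if p == 1 and m == 0)
    fn = sum(1 for p, m in pairs if p == 0 and m == 1)
    tn = sum(1 for p, m in pairs if p == 0 and m == 0)
    scores = {
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        "accuracy": (tp + tn) / len(pairs) if pairs else 0.0,
    }
    unknown = [name for name in metrics if name not in scores]
    if unknown:
        raise ValueError(f"unsupported metrics: {unknown}")
    # One value per requested indicator, in the order given by metrics.
    return [scores[name] for name in metrics]


# Four member samples followed by four non-member samples.
is_member = [1, 1, 1, 1, 0, 0, 0, 0]
predictions = [1, 1, 1, 0, 1, 0, 0, 0]  # made-up attack model outputs
result = membership_metrics(predictions, is_member, ["precision", "recall", "accuracy"])
print(result)
```

Values well above 0.5 on a balanced split suggest that the target model leaks information about which samples it was trained on.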

train(dataset_train, dataset_test, attack_config)[source]

Depending on the configuration, use the input dataset to train the attack model. Save the attack model to self._attack_list.

Parameters
  • dataset_train (mindspore.dataset) – The training dataset for the target model.

  • dataset_test (mindspore.dataset) – The test set for the target model.

  • attack_config (Union[list, tuple]) – Parameter settings for the attack model. The format is [{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}, {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}]. The supported methods are knn, lr, mlp and rf, and the params of each method must be within the range of its changeable parameters. Tips on setting params can be found in the documentation for each method: KNN, LR, RF, MLP.

Raises
  • KeyError – If any config in attack_config doesn’t have keys {“method”, “params”}.

  • NameError – If the method (case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].