mindarmour.diff_privacy

This module provides Differential Privacy features to protect user privacy.

class mindarmour.diff_privacy.AdaClippingWithGaussianRandom(decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)

Adaptive clipping. If decay_policy is ‘Linear’, the update formula is norm_bound = norm_bound - learning_rate*(beta - target_unclipped_quantile). If decay_policy is ‘Geometric’, the update formula is norm_bound = norm_bound*exp(-learning_rate*(beta - target_unclipped_quantile)). Here beta = empirical_fraction + N(0, fraction_stddev), where empirical_fraction is the fraction of samples whose gradient l2 norm is at most norm_bound, i.e. the fraction of samples left unclipped.
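The two update rules can be sketched in plain Python. This is a hypothetical helper (ada_clip_update), not MindArmour's implementation: it assumes scalar inputs and uses Python's random module in place of the MindSpore normal-distribution operator the class relies on.

```python
import math
import random


def ada_clip_update(empirical_fraction, norm_bound, decay_policy="Linear",
                    learning_rate=0.001, target_unclipped_quantile=0.9,
                    fraction_stddev=0.01, rng=None):
    """Return the next norm_bound under the Linear or Geometric rule.

    Hypothetical pure-Python helper, not part of MindArmour.
    """
    if rng is None:
        rng = random.Random()
    # beta = empirical_fraction + N(0, fraction_stddev)
    beta = empirical_fraction + rng.gauss(0.0, fraction_stddev)
    delta = beta - target_unclipped_quantile
    if decay_policy == "Linear":
        return norm_bound - learning_rate * delta
    if decay_policy == "Geometric":
        return norm_bound * math.exp(-learning_rate * delta)
    raise ValueError("decay_policy must be in ['Linear', 'Geometric']")


# With the noise switched off (fraction_stddev=0.0) the update is deterministic.
print(ada_clip_update(0.5, 1.0, "Linear", fraction_stddev=0.0))
print(ada_clip_update(0.5, 1.0, "Geometric", fraction_stddev=0.0))
```

When fewer samples than the target are unclipped (beta < target_unclipped_quantile), both rules enlarge norm_bound; when more are unclipped, both shrink it.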

Parameters
  • decay_policy (str) – Decay policy of adaptive clipping; must be in [‘Linear’, ‘Geometric’]. Default: ‘Linear’.

  • learning_rate (float) – Learning rate for updating the norm clip bound. Default: 0.001.

  • target_unclipped_quantile (float) – Target quantile of the norm clip, i.e. the desired fraction of unclipped samples. Default: 0.9.

  • fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction; the formula is empirical_fraction + N(0, fraction_stddev). Default: 0.01.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses a secure random number; if seed!=0, it generates values using the given seed. Default: 0.

Returns

Tensor, updated norm clip bound.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.diff_privacy import AdaClippingWithGaussianRandom
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
...                                          learning_rate=learning_rate,
...                                          target_unclipped_quantile=target_unclipped_quantile,
...                                          fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)

construct(empirical_fraction, norm_bound)

Update the value of norm_bound.

Parameters
  • empirical_fraction (Tensor) – Empirical fraction of samples whose gradient l2 norm is at most norm_bound (the unclipped fraction).

  • norm_bound (Tensor) – Clipping bound for the l2 norm of the gradients.

Returns

Tensor, the updated norm_bound.

class mindarmour.diff_privacy.ClipMechanismsFactory

Factory class of clip mechanisms.

static create(mech_name, decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)

Parameters
  • mech_name (str) – Strategy for generating clip noise; only ‘Gaussian’ is currently supported.

  • decay_policy (str) – Decay policy of adaptive clipping; must be in [‘Linear’, ‘Geometric’]. Default: ‘Linear’.

  • learning_rate (float) – Learning rate for updating the norm clip bound. Default: 0.001.

  • target_unclipped_quantile (float) – Target quantile of the norm clip, i.e. the desired fraction of unclipped samples. Default: 0.9.

  • fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction; the formula is empirical_fraction + N(0, fraction_stddev). Default: 0.01.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses a secure random number; if seed!=0, it generates values using the given seed. Default: 0.

Raises

NameError – mech_name must be in [‘Gaussian’].

Returns

Mechanisms, the created clip mechanism object.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.diff_privacy import ClipMechanismsFactory
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
>>> clip_mechanism = ClipMechanismsFactory()
>>> ada_clip = clip_mechanism.create('Gaussian',
...                                  decay_policy=decay_policy,
...                                  learning_rate=learning_rate,
...                                  target_unclipped_quantile=target_unclipped_quantile,
...                                  fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)

class mindarmour.diff_privacy.DPModel(micro_batches=2, norm_bound=1.0, noise_mech=None, clip_mech=None, **kwargs)

This class overloads mindspore.train.model.Model.

Parameters
  • micro_batches (int) – The number of small batches split from an original batch. Default: 2.

  • norm_bound (float) – Bound used to clip the l2 norm of the gradients. Default: 1.0.

  • noise_mech (Mechanisms) – Object that generates noise; different types of noise are supported. Default: None.

  • clip_mech (Mechanisms) – Object used to update the adaptive clip bound. Default: None.
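The way micro_batches, norm_bound and the noise level interact follows the general DP-SGD recipe: clip each micro-batch gradient, sum, add Gaussian noise, average. The sketch below is a hypothetical plain-Python illustration of that recipe, not DPModel's actual code. It assumes gradients are flat lists of floats, one per micro-batch, and that the noise standard deviation is noise_multiplier * norm_bound, consistent with how initial_noise_multiplier is defined for the noise mechanisms. The returned unclipped fraction is the kind of quantity an adaptive clip mechanism consumes as empirical_fraction.

```python
import math
import random


def l2_norm(vec):
    return math.sqrt(sum(v * v for v in vec))


def clip_to_norm(vec, norm_bound):
    # Scale vec down so that its l2 norm is at most norm_bound.
    norm = l2_norm(vec)
    scale = min(1.0, norm_bound / norm) if norm > 0 else 1.0
    return [v * scale for v in vec]


def dp_aggregate(micro_grads, norm_bound, noise_multiplier, rng=None):
    """Clip each micro-batch gradient, sum, add Gaussian noise, and average.

    Hypothetical helper illustrating generic DP-SGD aggregation.
    Returns (noisy_mean_gradient, empirical_fraction).
    """
    if rng is None:
        rng = random.Random()
    n = len(micro_grads)
    # Fraction of micro-batch gradients that did not need clipping.
    empirical_fraction = sum(l2_norm(g) <= norm_bound for g in micro_grads) / n
    total = [0.0] * len(micro_grads[0])
    for g in micro_grads:
        for i, v in enumerate(clip_to_norm(g, norm_bound)):
            total[i] += v
    # Noise standard deviation is noise_multiplier * norm_bound.
    stddev = noise_multiplier * norm_bound
    noisy_mean = [(t + rng.gauss(0.0, stddev)) / n for t in total]
    return noisy_mean, empirical_fraction


print(dp_aggregate([[3.0, 4.0], [0.3, 0.4]], norm_bound=1.0, noise_multiplier=1.0))
```

Clipping before the sum bounds each micro-batch's contribution by norm_bound, which is what lets the added noise be calibrated to norm_bound alone.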

Examples

>>> import mindspore.nn as nn
>>> import mindspore.dataset as ds
>>> from mindarmour.diff_privacy import ClipMechanismsFactory, DPModel, DPOptimizerClassFactory
>>> # LeNet5 and dataset_generator are assumed to be defined by the user (not shown)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> micro_batches = 2
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt.set_mechanisms('Gaussian',
...                            norm_bound=norm_bound,
...                            initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
...                                          learning_rate=0.1, momentum=0.9)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
...                                            decay_policy='Linear',
...                                            learning_rate=0.01,
...                                            target_unclipped_quantile=0.9,
...                                            fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
...                 norm_bound=norm_bound,
...                 clip_mech=clip_mech,
...                 noise_mech=None,
...                 network=network,
...                 loss_fn=loss,
...                 optimizer=net_opt,
...                 metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
...                             ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size*batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)

class mindarmour.diff_privacy.DPOptimizerClassFactory(micro_batches=2)

Factory class of DP optimizers.

Parameters

micro_batches (int) – The number of small batches split from an original batch. Default: 2.

Returns

Optimizer, an optimizer class.

Examples

>>> from mindarmour.diff_privacy import DPOptimizerClassFactory
>>> # network and cfg are assumed to be defined by the user (not shown)
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
>>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
...                                          learning_rate=cfg.lr,
...                                          momentum=cfg.momentum)

create(policy, *args, **kwargs)

Create DP optimizer.

Parameters

policy (str) – Type of the original optimizer to wrap, e.g. ‘Momentum’.

Returns

Optimizer, an optimizer class with differential privacy.

set_mechanisms(policy, *args, **kwargs)

Set the noise mechanism object used by the DP optimizer.

Parameters

policy (str) – Type of noise mechanism, e.g. ‘Gaussian’.
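DPOptimizerClassFactory.create returns a class rather than an instance: in the example it is called with the parameters afterwards. A hypothetical plain-Python sketch of that pattern follows, with a toy SGD optimizer standing in for the original optimizer and any callable standing in for the noise mechanism. None of these names (SGD, OptimizerClassFactory, step) are MindArmour APIs.

```python
class SGD:
    """Toy stand-in for an original optimizer (hypothetical)."""

    def __init__(self, params, learning_rate):
        self.params = list(params)
        self.learning_rate = learning_rate

    def step(self, grads):
        self.params = [p - self.learning_rate * g for p, g in zip(self.params, grads)]


class OptimizerClassFactory:
    """Hypothetical sketch: wrap an optimizer class so gradients get noise first."""

    _registry = {"SGD": SGD}

    def __init__(self, micro_batches=2):
        # Stored only to mirror the real constructor; unused in this sketch.
        self.micro_batches = micro_batches
        self._mech = None

    def set_mechanisms(self, mech):
        # mech: any callable mapping a gradient list to a same-length noise list.
        self._mech = mech

    def create(self, policy):
        base = self._registry[policy]
        mech = self._mech

        class DPOptimizer(base):
            def step(self, grads):
                # Add mechanism noise, then defer to the original update rule.
                noise = mech(grads)
                super().step([g + n for g, n in zip(grads, noise)])

        return DPOptimizer


factory = OptimizerClassFactory(micro_batches=2)
factory.set_mechanisms(lambda grads: [0.0] * len(grads))
opt = factory.create("SGD")([1.0, 2.0], learning_rate=0.5)
opt.step([2.0, 2.0])
print(opt.params)
```

Returning a subclass keeps the original optimizer's constructor signature, so callers pass the same arguments they would pass to the unwrapped optimizer.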

class mindarmour.diff_privacy.NoiseAdaGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy='Exp')

Adaptive Gaussian noise generation mechanism. The noise decays during training; the decay mode can be ‘Time’, ‘Step’ or ‘Exp’. self._noise_multiplier is updated during model.train by _MechanismsParamsUpdater.

Parameters
  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses a secure random number; if seed!=0, it generates values using the given seed. Default: 0.

  • noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.

  • decay_policy (str) – Noise decay strategy, one of ‘Step’, ‘Time’, ‘Exp’. Default: ‘Exp’.
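The decay modes are named here but their formulas are not given. As a hedged illustration only, the sketch below assumes common textbook schedules: inverse-time decay for ‘Time’, a constant multiplicative factor per step for ‘Step’, and exponential decay for ‘Exp’. MindArmour's exact update rule may differ, and decay_noise_multiplier is a hypothetical helper.

```python
import math


def decay_noise_multiplier(initial, step, rate, policy="Exp"):
    """Noise multiplier after `step` updates under an ASSUMED decay schedule."""
    if policy == "Time":
        # Assumed inverse-time decay: sigma_t = sigma_0 / (1 + rate * t)
        return initial / (1.0 + rate * step)
    if policy == "Step":
        # Assumed multiplicative decay: sigma_t = sigma_0 * (1 - rate) ** t
        return initial * (1.0 - rate) ** step
    if policy == "Exp":
        # Assumed exponential decay: sigma_t = sigma_0 * exp(-rate * t)
        return initial * math.exp(-rate * step)
    raise ValueError("decay_policy must be one of 'Step', 'Time', 'Exp'")


for policy in ("Time", "Step", "Exp"):
    print(policy, decay_noise_multiplier(1.5, 1000, 6e-4, policy))
```

All three schedules start at initial_noise_multiplier and shrink monotonically, so later training steps receive less noise and therefore spend more privacy budget per step.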

Returns

Tensor, generated noise with shape like given gradients.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.diff_privacy import NoiseAdaGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> noise_decay_rate = 6e-4
>>> decay_policy = "Exp"
>>> net = NoiseAdaGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_decay_rate, decay_policy)
>>> res = net(gradients)
>>> print(res)

class mindarmour.diff_privacy.NoiseGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None)

Gaussian noise generation mechanism.

Parameters
  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses a secure random number; if seed!=0, it generates values using the given seed. Default: 0.

  • decay_policy (str) – Update policy for the mechanism's parameters. Default: None.
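The relationship between norm_bound, initial_noise_multiplier and seed can be illustrated with a hypothetical pure-Python helper (gaussian_noise, not the MindSpore operator this class uses). It assumes a flat list of gradient values; the noise standard deviation is initial_noise_multiplier * norm_bound, seed=0 selects an OS-backed secure generator (random.SystemRandom), and a non-zero seed gives reproducible draws.

```python
import random


def gaussian_noise(gradients, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0):
    """Return one N(0, initial_noise_multiplier * norm_bound) draw per gradient entry.

    Hypothetical helper; gradients is a flat list of floats.
    """
    # seed == 0: secure, non-reproducible randomness; otherwise a seeded generator.
    rng = random.SystemRandom() if seed == 0 else random.Random(seed)
    stddev = initial_noise_multiplier * norm_bound
    return [rng.gauss(0.0, stddev) for _ in gradients]


print(gaussian_noise([0.2, 0.9], norm_bound=0.5, initial_noise_multiplier=1.5, seed=1))
```

A secure source is the safer default for real training, because an attacker who can predict the noise can subtract it; a fixed seed is mainly useful for debugging and tests.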

Returns

Tensor, generated noise with shape like given gradients.

Examples

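>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.diff_privacy import NoiseGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> decay_policy = None
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier, seed, decay_policy)
>>> res = net(gradients)
>>> print(res)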

construct(gradients)

Generate Gaussian noise.

Parameters

gradients (Tensor) – The gradients.

Returns

Tensor, generated noise with shape like given gradients.

class mindarmour.diff_privacy.NoiseMechanismsFactory

Factory class of noise mechanisms.

static create(mech_name, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy=None)

Parameters
  • mech_name (str) – Noise generation strategy, either ‘Gaussian’ or ‘AdaGaussian’. The noise decays during training with the ‘AdaGaussian’ mechanism and stays constant with the ‘Gaussian’ mechanism.

  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses a secure random number; if seed!=0, it generates values using the given seed. Default: 0.

  • noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.

  • decay_policy (str) – Update policy for the mechanism's parameters. Default: None, meaning no parameters need to be updated.

Raises

NameError – mech_name must be in [‘Gaussian’, ‘AdaGaussian’].

Returns

Mechanisms, the created noise generation mechanism object.
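The factory's dispatch and its NameError contract can be sketched with two hypothetical stand-in classes (GaussianMech, AdaGaussianMech) and a hypothetical create_noise_mech function; none of them are MindArmour APIs. The ‘AdaGaussian’ stand-in assumes an exponential decay step purely for illustration.

```python
import math


class GaussianMech:
    """Hypothetical stand-in: constant noise multiplier."""

    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, noise_decay_rate=6e-6):
        self.norm_bound = norm_bound
        self.noise_multiplier = initial_noise_multiplier
        self.noise_decay_rate = noise_decay_rate

    def update(self):
        # Constant noise: nothing changes between steps.
        pass


class AdaGaussianMech(GaussianMech):
    """Hypothetical stand-in: noise multiplier decays after every step."""

    def update(self):
        # Assumed exponential decay, for illustration only.
        self.noise_multiplier *= math.exp(-self.noise_decay_rate)


def create_noise_mech(mech_name, **kwargs):
    mechs = {"Gaussian": GaussianMech, "AdaGaussian": AdaGaussianMech}
    if mech_name not in mechs:
        raise NameError("mech_name must be in ['Gaussian', 'AdaGaussian']")
    return mechs[mech_name](**kwargs)


mech = create_noise_mech("AdaGaussian", initial_noise_multiplier=1.5, noise_decay_rate=6e-4)
mech.update()
print(mech.noise_multiplier)
```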

Examples

>>> import mindspore.nn as nn
>>> import mindspore.dataset as ds
>>> from mindarmour.diff_privacy import ClipMechanismsFactory, DPModel, NoiseMechanismsFactory
>>> # LeNet5 and dataset_generator are assumed to be defined by the user (not shown)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
...                                              norm_bound=norm_bound,
...                                              initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
...                                            decay_policy='Linear',
...                                            learning_rate=0.01,
...                                            target_unclipped_quantile=0.9,
...                                            fraction_stddev=0.01)
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
...                       momentum=0.9)
>>> model = DPModel(micro_batches=2,
...                 clip_mech=clip_mech,
...                 norm_bound=norm_bound,
...                 noise_mech=noise_mech,
...                 network=network,
...                 loss_fn=loss,
...                 optimizer=net_opt,
...                 metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
...                             ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)

class mindarmour.diff_privacy.PrivacyMonitorFactory

Factory class of DP training’s privacy monitor.

static create(policy, *args, **kwargs)

Create a privacy monitor class.

Parameters
  • policy (str) – Monitor policy; ‘rdp’ and ‘zcdp’ are currently supported.

  • args (Union[int, float, numpy.ndarray, list, str]) – Parameters used for creating a privacy monitor.

  • kwargs (Union[int, float, numpy.ndarray, list, str]) – Keyword parameters used for creating a privacy monitor.

Returns

Callback, a privacy monitor.

Examples

>>> from mindarmour.diff_privacy import PrivacyMonitorFactory
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
...                                    num_samples=60000, batch_size=32)
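For intuition about what the ‘rdp’ and ‘zcdp’ policies account for, here is a simplified sketch of the standard conversions for a full-batch Gaussian mechanism with noise multiplier sigma. Each step costs alpha/(2*sigma**2) in Rényi DP of order alpha (or rho = 1/(2*sigma**2) in zCDP), costs add up over steps, and the totals are converted to (epsilon, delta). It deliberately ignores the subsampling amplification that a real monitor built from num_samples and batch_size accounts for, so its epsilon is much looser; rdp_epsilon and zcdp_epsilon are hypothetical names, not the monitor's API.

```python
import math


def rdp_epsilon(noise_multiplier, steps, delta, orders=(1.5, 2, 4, 8, 16, 32, 64)):
    """(epsilon, delta)-DP bound from RDP of the Gaussian mechanism, no subsampling."""
    best = math.inf
    for alpha in orders:
        # RDP of order alpha per step is alpha / (2 sigma^2); composition adds.
        rdp = steps * alpha / (2.0 * noise_multiplier ** 2)
        # Standard RDP -> (epsilon, delta) conversion.
        eps = rdp + math.log(1.0 / delta) / (alpha - 1.0)
        best = min(best, eps)
    return best


def zcdp_epsilon(noise_multiplier, steps, delta):
    """(epsilon, delta)-DP bound from zCDP of the Gaussian mechanism, no subsampling."""
    rho = steps / (2.0 * noise_multiplier ** 2)  # rho adds up over steps
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


print(rdp_epsilon(noise_multiplier=1.5, steps=10, delta=1e-5))
print(zcdp_epsilon(noise_multiplier=1.5, steps=10, delta=1e-5))
```

Minimizing over several RDP orders matters because the best order depends on the number of steps and on delta; a monitor typically evaluates a fixed grid of orders, as this sketch does.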