mindarmour.privacy.diff_privacy

This module provides differential privacy features to protect user privacy.

class mindarmour.privacy.diff_privacy.AdaClippingWithGaussianRandom(decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)

Adaptive clipping. If decay_policy is ‘Linear’, the update formula is \(norm\_bound = norm\_bound - learning\_rate*(empirical\_fraction - target\_unclipped\_quantile)\). If decay_policy is ‘Geometric’, the update formula is \(norm\_bound = norm\_bound*exp(-learning\_rate*(empirical\_fraction - target\_unclipped\_quantile))\). Here empirical_fraction (called beta in the examples) is the fraction of samples whose gradient l2 norm is at most the current norm_bound, perturbed with Gaussian noise of standard deviation fraction_stddev.
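The two update rules can be sketched in plain Python as follows. This is a minimal illustration, not the MindArmour implementation: the function name update_norm_bound is hypothetical, it works on scalar floats instead of MindSpore Tensors, and it draws the fraction noise from Python's random module rather than the seed-based or secure generator used by the class.

```python
import math
import random


def update_norm_bound(empirical_fraction, norm_bound, decay_policy="Linear",
                      learning_rate=0.001, target_unclipped_quantile=0.9,
                      fraction_stddev=0.01, rng=None):
    """Hypothetical scalar sketch of one adaptive-clipping update."""
    if rng is None:
        rng = random.Random()
    # Perturb the empirical fraction with N(0, fraction_stddev) noise.
    noisy_fraction = empirical_fraction + rng.gauss(0.0, fraction_stddev)
    gap = noisy_fraction - target_unclipped_quantile
    if decay_policy == "Linear":
        # Additive step: shrink the bound when too many samples are unclipped.
        return norm_bound - learning_rate * gap
    if decay_policy == "Geometric":
        # Multiplicative step: same direction, scaled by exp(-lr * gap).
        return norm_bound * math.exp(-learning_rate * gap)
    raise ValueError("decay_policy must be in ['Linear', 'Geometric']")
```

With fraction_stddev=0.0 the update is deterministic: update_norm_bound(0.5, 1.0, fraction_stddev=0.0) gives 1.0004, so the bound grows slightly because fewer samples than the target quantile fall under it.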

Parameters
  • decay_policy (str) – Decay policy of adaptive clipping, decay_policy must be in [‘Linear’, ‘Geometric’]. Default: ‘Linear’.

  • learning_rate (float) – Learning rate for updating the norm clip. Default: 0.001.

  • target_unclipped_quantile (float) – Target quantile of norm clip. Default: 0.9.

  • fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction, i.e. \(empirical\_fraction + N(0, fraction\_stddev)\). Default: 0.01.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers. If seed!=0, it generates values using the given seed. Default: 0.

Returns

Tensor, the updated norm clip.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import AdaClippingWithGaussianRandom
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
...                                          learning_rate=learning_rate,
...                                          target_unclipped_quantile=target_unclipped_quantile,
...                                          fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)
construct(empirical_fraction, norm_bound)

Update the value of norm_bound.

Parameters
  • empirical_fraction (Tensor) – Empirical fraction of samples whose gradient l2 norm is at most norm_bound.

  • norm_bound (Tensor) – Clipping bound for the l2 norm of the gradients.

Returns

Tensor, the updated norm_bound.

class mindarmour.privacy.diff_privacy.ClipMechanismsFactory

Factory class of clip mechanisms, a wrapper of clip noise generation mechanisms. It currently supports only adaptive clipping with Gaussian random noise.

For details, please check Tutorial.

static create(mech_name, decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)
Parameters
  • mech_name (str) – Clip noise generation strategy; only ‘Gaussian’ is currently supported.

  • decay_policy (str) – Decay policy of adaptive clipping, decay_policy must be in [‘Linear’, ‘Geometric’]. Default: ‘Linear’.

  • learning_rate (float) – Learning rate for updating the norm clip. Default: 0.001.

  • target_unclipped_quantile (float) – Target quantile of norm clip. Default: 0.9.

  • fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction, i.e. \(empirical\_fraction + N(0, fraction\_stddev)\). Default: 0.01.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers. If seed!=0, it generates values using the given seed. Default: 0.

Raises

NameError – mech_name must be in [‘Gaussian’].

Returns

Mechanisms, the clip mechanism object.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import ClipMechanismsFactory
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> clip_mechanism = ClipMechanismsFactory()
>>> ada_clip = clip_mechanism.create('Gaussian',
...                                  decay_policy=decay_policy,
...                                  learning_rate=learning_rate,
...                                  target_unclipped_quantile=target_unclipped_quantile,
...                                  fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)
class mindarmour.privacy.diff_privacy.DPModel(micro_batches=2, norm_bound=1.0, noise_mech=None, clip_mech=None, **kwargs)

DPModel is used for constructing a model for differential privacy training. This class overloads mindspore.train.model.Model.

For details, please check Tutorial.

Parameters
  • micro_batches (int) – The number of small batches split from an original batch. Default: 2.

  • norm_bound (float) – The clipping bound for the l2 norm of the gradients. Default: 1.0.

  • noise_mech (Mechanisms) – The object that generates the chosen type of noise. Default: None.

  • clip_mech (Mechanisms) – The object used to update the adaptive clipping bound. Default: None.

Raises
  • ValueError – If DPOptimizer and noise_mech are both None or both not None (exactly one of them must be provided).

  • ValueError – If the mechanism of noise_mech or of the DPOptimizer is adaptive while clip_mech is not None.
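Differentially private training with micro-batches commonly follows the DP-SGD aggregation pattern: clip each micro-batch gradient to norm_bound, sum the clipped gradients, add Gaussian noise, and average. The NumPy sketch below illustrates that general pattern only; dp_aggregate and noise_multiplier are hypothetical names, the per-micro-batch gradients are assumed to be already computed, and the noise standard deviation is assumed to be norm_bound * noise_multiplier. It is not DPModel's actual code.

```python
import numpy as np


def dp_aggregate(micro_batch_grads, norm_bound=1.0, noise_multiplier=1.5, rng=None):
    """Hypothetical DP-SGD style aggregation of per-micro-batch gradients."""
    if rng is None:
        rng = np.random.default_rng()
    grads = np.asarray(micro_batch_grads, dtype=np.float64)  # (micro_batches, dim)
    norms = np.linalg.norm(grads, axis=1, keepdims=True)
    # Scale each row by min(1, norm_bound / ||g||), guarding against a zero norm.
    factors = np.minimum(1.0, norm_bound / np.maximum(norms, 1e-12))
    clipped_sum = (grads * factors).sum(axis=0)
    # Add Gaussian noise once to the sum, then average over micro-batches.
    noise = rng.normal(0.0, norm_bound * noise_multiplier, size=clipped_sum.shape)
    return (clipped_sum + noise) / grads.shape[0]
```

For example, with noise_multiplier=0.0 the gradients [[3.0, 4.0], [0.0, 0.5]] and norm_bound=1.0 yield [0.3, 0.65]: the first row (norm 5) is scaled down to [0.6, 0.8], the second is left unchanged.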

class mindarmour.privacy.diff_privacy.DPOptimizerClassFactory(micro_batches=2)

Factory class of Optimizer.

Parameters

micro_batches (int) – The number of small batches split from an original batch. Default: 2.

Returns

Optimizer, Optimizer class.

Examples

>>> from mindarmour.privacy.diff_privacy import DPOptimizerClassFactory
>>> from tests.ut.python.utils.mock_net import Net
>>> network = Net()
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
>>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
...                                          learning_rate=0.001,
...                                          momentum=0.9)
create(policy)

Create a DP optimizer. Policies can be ‘sgd’, ‘momentum’ or ‘adam’.

Parameters

policy (str) – The original optimizer type to wrap.

Returns

Optimizer, an optimizer with DP.

set_mechanisms(policy, *args, **kwargs)

Set the noise mechanism object. Policies can be ‘Gaussian’ or ‘AdaGaussian’. Candidate args and kwargs are those of class NoiseMechanismsFactory in mechanisms.py.

Parameters

policy (str) – The noise mechanism type.

class mindarmour.privacy.diff_privacy.NoiseAdaGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy='Exp')

Adaptive Gaussian noise generation mechanism. The noise decays as training proceeds; the decay mode can be ‘Time’, ‘Step’ or ‘Exp’. self._noise_multiplier is updated during the model training process.

Parameters
  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers. If seed!=0, it generates values using the given seed. Default: 0.

  • noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.

  • decay_policy (str) – Noise decay strategy, one of ‘Step’, ‘Time’ or ‘Exp’. Default: ‘Exp’.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import NoiseAdaGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.0
>>> seed = 0
>>> noise_decay_rate = 6e-6
>>> decay_policy = "Exp"
>>> net = NoiseAdaGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_decay_rate, decay_policy)
>>> res = net(gradients)
construct(gradients)

Generate adaptive Gaussian noise.

Parameters

gradients (Tensor) – The gradients.

Returns

Tensor, generated noise with the same shape as the given gradients.

class mindarmour.privacy.diff_privacy.NoiseGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None)

Generate noise in Gaussian Distribution with \(mean=0\) and \(standard\_deviation = norm\_bound * initial\_noise\_multiplier\).
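A NumPy sketch of the distribution described above, assuming gradients is array-like. gaussian_noise is a hypothetical helper: unlike the class, it takes a NumPy Generator instead of the seed parameter and does not use a secure random source.

```python
import numpy as np


def gaussian_noise(gradients, norm_bound=1.0, initial_noise_multiplier=1.0, rng=None):
    """Hypothetical sketch: zero-mean Gaussian noise shaped like the gradients."""
    if rng is None:
        rng = np.random.default_rng()
    # standard_deviation = norm_bound * initial_noise_multiplier
    stddev = norm_bound * initial_noise_multiplier
    return rng.normal(0.0, stddev, size=np.shape(gradients))
```

For instance, gaussian_noise([0.2, 0.9], norm_bound=0.1) returns two samples drawn from N(0, 0.1²).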

Parameters
  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers. If seed!=0, it generates values using the given seed. Default: 0.

  • decay_policy (str) – Mechanism parameter update policy. Default: None.

Examples

>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import NoiseGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.1
>>> initial_noise_multiplier = 1.0
>>> seed = 0
>>> decay_policy = None
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier, seed, decay_policy)
>>> res = net(gradients)
construct(gradients)

Generate Gaussian noise.

Parameters

gradients (Tensor) – The gradients.

Returns

Tensor, generated noise with the same shape as the given gradients.

class mindarmour.privacy.diff_privacy.NoiseMechanismsFactory

Factory class of noise mechanisms, a wrapper of noise generation mechanisms. It currently supports Gaussian random noise and adaptive Gaussian random noise.

For details, please check Tutorial.

static create(mech_name, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy=None)
Parameters
  • mech_name (str) – Noise generation strategy, either ‘Gaussian’ or ‘AdaGaussian’. The noise decays with the ‘AdaGaussian’ mechanism and stays constant with the ‘Gaussian’ mechanism.

  • norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.

  • initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.0.

  • seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers. If seed!=0, it generates values using the given seed. Default: 0.

  • noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.

  • decay_policy (str) – Mechanism parameter update policy. If decay_policy is None, no parameters are updated. Default: None.

Raises

NameError – mech_name must be in [‘Gaussian’, ‘AdaGaussian’].

Returns

Mechanisms, the noise generation mechanism object.

Examples

>>> from mindarmour.privacy.diff_privacy import NoiseMechanismsFactory
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.0
>>> noise_mechanism = NoiseMechanismsFactory()
>>> noise_mech = noise_mechanism.create('Gaussian',
...                                     norm_bound=norm_bound,
...                                     initial_noise_multiplier=initial_noise_multiplier)
class mindarmour.privacy.diff_privacy.PrivacyMonitorFactory

Factory class of DP training’s privacy monitor. For details, please check Tutorial.

static create(policy, *args, **kwargs)

Create a privacy monitor class.

Parameters
  • policy (str) – Monitor policy; ‘rdp’ and ‘zcdp’ are currently supported. If policy is ‘rdp’, the monitor computes the privacy budget of DP training based on Rényi differential privacy theory; if policy is ‘zcdp’, it computes the privacy budget based on zero-concentrated differential privacy theory. Note that ‘zcdp’ is not suitable for subsampling noise mechanisms.

  • args (Union[int, float, numpy.ndarray, list, str]) – Parameters used for creating a privacy monitor.

  • kwargs (Union[int, float, numpy.ndarray, list, str]) – Keyword parameters used for creating a privacy monitor.

Returns

Callback, a privacy monitor.

Examples

>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> rdp = PrivacyMonitorFactory.create(policy='rdp', num_samples=60000, batch_size=32)
class mindarmour.privacy.diff_privacy.RDPMonitor(num_samples, batch_size, initial_noise_multiplier=1.5, max_eps=10.0, target_delta=0.001, max_delta=None, target_eps=None, orders=None, noise_decay_mode='Time', noise_decay_rate=0.0006, per_print_times=50, dataset_sink_mode=False)

Compute the privacy budget of DP training based on Rényi differential privacy (RDP) theory. According to the reference below, if a randomized mechanism satisfies ε’-Rényi differential privacy of order α, it also satisfies conventional (ε, δ)-differential privacy with:

\[(ε'+\frac{\log(1/δ)}{α-1}, δ)\]
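When several orders are tried, this conversion is typically evaluated for each order and the smallest ε is kept. The helper below sketches that step; rdp_to_dp is a hypothetical name, and the per-order RDP values are assumed to be already computed (computing them for the sampled Gaussian mechanism is not shown).

```python
import math


def rdp_to_dp(orders, rdp_eps, target_delta):
    """Hypothetical sketch: best (epsilon, order) from per-order RDP values."""
    best_eps, best_order = math.inf, None
    for alpha, eps_rdp in zip(orders, rdp_eps):
        if alpha <= 1:
            raise ValueError("orders must be greater than 1")
        # epsilon = eps' + log(1/delta) / (alpha - 1) for this order.
        eps = eps_rdp + math.log(1.0 / target_delta) / (alpha - 1)
        if eps < best_eps:
            best_eps, best_order = eps, alpha
    return best_eps, best_order
```

For orders [2, 4, 8] with illustrative RDP values [0.5, 1.0, 2.0] and δ = 1e-5, the largest order gives the tightest bound, about 3.64, even though its RDP value is the largest.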

For details, please check Tutorial.

Reference: Rényi Differential Privacy of the Sampled Gaussian Mechanism

Parameters
  • num_samples (int) – The total number of samples in training data sets.

  • batch_size (int) – The number of samples in a batch while training.

  • initial_noise_multiplier (Union[float, int]) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.5.

  • max_eps (Union[float, int, None]) – The maximum acceptable epsilon budget for DP training, which is used for estimating the max training epochs. ‘None’ means there is no limit to the epsilon budget. Default: 10.0.

  • target_delta (Union[float, int, None]) – Target delta budget for DP training. If target_delta is set to δ, the privacy budget δ is fixed during the whole training process. Default: 1e-3.

  • max_delta (Union[float, int, None]) – The maximum acceptable delta budget for DP training, which is used for estimating the max training epochs. max_delta must be less than 1 and is suggested to be less than 1e-3; otherwise overflow may be encountered. ‘None’ means there is no limit to the delta budget. Default: None.

  • target_eps (Union[float, int, None]) – Target epsilon budget for DP training. If target_eps is set to ε, the privacy budget ε is fixed during the whole training process. Default: None.

  • orders (Union[None, list[int, float]]) – Finite orders used for computing RDP, which must be greater than 1. The computed privacy budget differs across orders; to obtain a tighter (smaller) privacy budget estimate, a list of orders can be tried. Default: None.

  • noise_decay_mode (Union[None, str]) – Decay mode of adding noise while training, which can be None, ‘Time’, ‘Step’ or ‘Exp’. Default: ‘Time’.

  • noise_decay_rate (float) – Decay rate of noise while training. Default: 6e-4.

  • per_print_times (int) – The interval steps of computing and printing the privacy budget. Default: 50.

  • dataset_sink_mode (bool) – If True, all training data is passed to the device (Ascend) at once. If False, training data is passed to the device after each training step. Default: False.

Examples

>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> rdp = PrivacyMonitorFactory.create(policy='rdp', num_samples=100, batch_size=32)
max_epoch_suggest()

Estimate the maximum training epochs to satisfy the predefined privacy budget.

Returns

int, the recommended maximum training epochs.

step_end(run_context)

Compute privacy budget after each training step.

Parameters

run_context (RunContext) – Contains information about the model.

class mindarmour.privacy.diff_privacy.ZCDPMonitor(num_samples, batch_size, initial_noise_multiplier=1.5, max_eps=10.0, target_delta=0.001, noise_decay_mode='Time', noise_decay_rate=0.0006, per_print_times=50, dataset_sink_mode=False)

Compute the privacy budget of DP training based on zero-concentrated differential privacy (zCDP) theory. According to the reference below, if a randomized mechanism satisfies ρ-zCDP, it also satisfies conventional (ε, δ)-differential privacy with:

\[(ρ+2\sqrt{ρ\log(1/δ)}, δ)\]
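The conversion can be written directly in Python. zcdp_to_dp is a hypothetical helper, and the per-step ρ values in the usage lines are made-up numbers; the example relies on the standard property that zCDP composes additively, so the ρ of successive steps can be summed before converting.

```python
import math


def zcdp_to_dp(rho, target_delta):
    """Hypothetical sketch: epsilon of the (epsilon, delta) pair implied by rho-zCDP."""
    if rho < 0 or not 0 < target_delta < 1:
        raise ValueError("rho must be >= 0 and target_delta must lie in (0, 1)")
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / target_delta))


# zCDP composes additively: total rho is the sum of the per-step values
# (illustrative numbers only).
per_step_rho = [0.01, 0.01, 0.02]
total_eps = zcdp_to_dp(sum(per_step_rho), target_delta=1e-5)
```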

It should be noted that ZCDPMonitor is not suitable for subsampling noise mechanisms (such as NoiseAdaGaussianRandom and NoiseGaussianRandom). A noise mechanism matching zCDP will be developed in the future.

For details, please check Tutorial.

Reference: Concentrated Differentially Private Gradient Descent with Adaptive per-Iteration Privacy Budget

Parameters
  • num_samples (int) – The total number of samples in training data sets.

  • batch_size (int) – The number of samples in a batch while training.

  • initial_noise_multiplier (Union[float, int]) – Ratio of the standard deviation of the Gaussian noise to norm_bound, which is used to calculate the privacy spent. Default: 1.5.

  • max_eps (Union[float, int]) – The maximum acceptable epsilon budget for DP training, which is used for estimating the max training epochs. Default: 10.0.

  • target_delta (Union[float, int]) – Target delta budget for DP training. If target_delta is set to δ, the privacy budget δ is fixed during the whole training process. Default: 1e-3.

  • noise_decay_mode (Union[None, str]) – Decay mode of adding noise while training, which can be None, ‘Time’, ‘Step’ or ‘Exp’. Default: ‘Time’.

  • noise_decay_rate (float) – Decay rate of noise while training. Default: 6e-4.

  • per_print_times (int) – The interval steps of computing and printing the privacy budget. Default: 50.

  • dataset_sink_mode (bool) – If True, all training data is passed to the device (Ascend) at once. If False, training data is passed to the device after each training step. Default: False.

Examples

>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> zcdp = PrivacyMonitorFactory.create(policy='zcdp',
...                                     num_samples=100,
...                                     batch_size=32,
...                                     initial_noise_multiplier=1.5)
max_epoch_suggest()

Estimate the maximum training epochs to satisfy the predefined privacy budget.

Returns

int, the recommended maximum training epochs.
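A simplified sketch of how such an estimate can be derived for zCDP, under assumptions that may differ from ZCDPMonitor's actual computation: constant noise (no decay), no subsampling amplification, sensitivity equal to norm_bound so that each step is 1/(2σ²)-zCDP for noise multiplier σ, and ceil(num_samples / batch_size) steps per epoch. The function name is hypothetical. It inverts \(ε = ρ + 2\sqrt{ρ\log(1/δ)}\) to find the largest total ρ within max_eps.

```python
import math


def max_epoch_suggest_sketch(num_samples, batch_size, noise_multiplier=1.5,
                             max_eps=10.0, target_delta=1e-3):
    """Hypothetical estimate of the largest epoch count within a zCDP budget."""
    log_term = math.log(1.0 / target_delta)
    # Invert eps = rho + 2*sqrt(rho*L): with x = sqrt(rho),
    # x**2 + 2*x*sqrt(L) - eps = 0, so x = sqrt(L + eps) - sqrt(L).
    rho_max = (math.sqrt(log_term + max_eps) - math.sqrt(log_term)) ** 2
    # Assumed per-step cost of a Gaussian mechanism with std = sigma * sensitivity.
    rho_per_step = 1.0 / (2.0 * noise_multiplier ** 2)
    max_steps = math.floor(rho_max / rho_per_step)
    # Assumed number of training steps per epoch.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return max_steps // steps_per_epoch
```

Under these assumptions, max_epoch_suggest_sketch(100, 32) returns 2: the default budget allows only 9 steps at σ = 1.5, and each epoch takes 4 steps.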

step_end(run_context)

Compute privacy budget after each training step.

Parameters

run_context (RunContext) – Contains information about the model.