# Source code for mindarmour.privacy.sup_privacy.mask_monitor.masker

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Masker module of suppress-based privacy.
"""
from mindspore.train.callback import Callback
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_param_type
from mindarmour.privacy.sup_privacy.train.model import SuppressModel
from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl

LOGGER = LogUtil.get_instance()
TAG = 'suppress masker'


class SuppressMasker(Callback):
    """
    Periodically check suppress privacy function status and toggle suppress operation.
    For details, please check `Protecting User Privacy with Suppression Privacy
    <https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_suppress_privacy.html>`_.

    Args:
        model (SuppressModel): SuppressModel instance.
        suppress_ctrl (SuppressCtrl): SuppressCtrl instance.

    Examples:
        >>> import mindspore.nn as nn
        >>> import mindspore as ms
        >>> from mindspore import set_context, ops
        >>> from mindspore.nn import Accuracy
        >>> from mindarmour.privacy.sup_privacy import SuppressModel
        >>> from mindarmour.privacy.sup_privacy import SuppressMasker
        >>> from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory
        >>> from mindarmour.privacy.sup_privacy import MaskLayerDes
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self._softmax = ops.Softmax()
        ...         self._Dense = nn.Dense(10,10)
        ...         self._squeeze = ops.Squeeze(1)
        ...     def construct(self, inputs):
        ...         out = self._softmax(inputs)
        ...         out = self._Dense(out)
        ...         return self._squeeze(out)
        >>> set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")
        >>> network = Net()
        >>> masklayers = []
        >>> masklayers.append(MaskLayerDes("_Dense.weight", 0, False, True, 10))
        >>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=network,
        ...                                                          mask_layers=masklayers,
        ...                                                          policy="local_train",
        ...                                                          end_epoch=10,
        ...                                                          batch_num=1,
        ...                                                          start_epoch=3,
        ...                                                          mask_times=10,
        ...                                                          lr=0.05,
        ...                                                          sparse_end=0.95,
        ...                                                          sparse_start=0.0)
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        >>> net_opt = nn.SGD(network.trainable_params(), 0.05)
        >>> model_instance = SuppressModel(network=network,
        ...                                loss_fn=net_loss,
        ...                                optimizer=net_opt,
        ...                                metrics={"Accuracy": Accuracy()})
        >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance)
        >>> masker_instance = SuppressMasker(model_instance, suppress_ctrl_instance)
    """

    def __init__(self, model, suppress_ctrl):
        super().__init__()
        # Both arguments are passed through check_param_type together with the
        # expected class; the value it returns is what gets stored.
        self._model = check_param_type('model', model, SuppressModel)
        self._suppress_ctrl = check_param_type('suppress_ctrl', suppress_ctrl, SuppressCtrl)

    def step_end(self, run_context):
        """
        Update mask matrix tensor used for SuppressModel instance.

        Called at the end of a training step. Forwards the current epoch/step
        counters to the suppress controller and then either updates the mask
        or resets zeros, depending on the controller's ``to_do_mask`` and
        ``mask_started`` flags.

        Args:
            run_context (RunContext): Include some information of the model.

        Raises:
            ValueError: If the suppress controller reports that its mask has
                not been initialized.
        """
        cb_params = run_context.original_args()
        cur_step = cb_params.cur_step_num
        # 1-based position of this step inside the current epoch.
        cur_step_in_epoch = (cur_step - 1) % cb_params.batch_num + 1

        ctrl = self._suppress_ctrl
        # Nothing to do unless both the controller and the model's
        # network_end are available.
        if ctrl is None or self._model.network_end is None:
            return
        if not ctrl.mask_initialized:
            raise ValueError("Not initialize network!")

        # On steps 1, 101, 201, ... of an epoch, run the controller's sparsity
        # calculations. Their return values are not used here.
        if cur_step_in_epoch % 100 == 1:
            ctrl.calc_theoretical_sparse_for_conv()
            ctrl.calc_actual_sparse_for_conv(ctrl.networks)

        # update_status runs before to_do_mask is read below.
        # NOTE(review): presumably it is what refreshes that flag -- confirm
        # against SuppressCtrl.
        ctrl.update_status(cb_params.cur_epoch_num, cur_step, cur_step_in_epoch)

        if ctrl.to_do_mask:
            ctrl.update_mask(ctrl.networks, cur_step)
            LOGGER.info(TAG, "suppress update")
        elif ctrl.mask_started:
            ctrl.reset_zeros()