mindspore.ops.silent_check.ASDBase

class mindspore.ops.silent_check.ASDBase(cls, *args, **kwargs)

ASDBase is the Python base class for operators that support the accuracy-sensitive detection feature.

Parameters
  • cls (Primitive) – The original operator that requires the accuracy-sensitive detection feature.

  • args (tuple) – Variable positional arguments passed to the original operator.

  • kwargs (dict) – Variable keyword arguments passed to the original operator.

Supported Platforms:

Ascend

Examples

>>> from mindspore.ops.silent_check import ASDBase
>>> from mindspore.ops import LayerNorm as OriginLayerNorm
>>> class LayerNormASD(ASDBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(OriginLayerNorm, *args, **kwargs)
...         # initialize the accuracy-sensitive detection parameters by calling the base class method generate_params()
...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
...
...     def __call__(self, input_x, gamma, beta):
...         if self.enable_check:
...             # run accuracy-sensitive detection by calling check_op from the base class
...             input_x = self.check_op(
...                 input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None)
...             self.cnt += 1
...         # return the result of original operator
...         return self.op(input_x, gamma, beta)
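The pattern in the example above (the base class builds the original operator from `cls`, `*args` and `**kwargs`, and the derived class runs a check before delegating to `self.op`) can be illustrated without MindSpore or Ascend hardware. The following is a minimal pure-Python sketch under stated assumptions: `ToyASDBase`, `Scale` and `ScaleASD` are hypothetical names, and the toy check (recording the L2 norm of the input) is only a stand-in for the real detection performed by `check_op`, not its actual behavior.

```python
import math


class ToyASDBase:
    """Hypothetical stand-in for ASDBase; not MindSpore's implementation."""

    def __init__(self, cls, *args, **kwargs):
        # Build the original operator from the forwarded arguments.
        self.op = cls(*args, **kwargs)
        # Hypothetical switch; MindSpore decides this differently.
        self.enable_check = True
        self.history = []

    def check_op(self, values):
        # Toy check: record the L2 norm of the input, then pass it through unchanged.
        norm = math.sqrt(sum(v * v for v in values))
        self.history.append(norm)
        return values


class Scale:
    """Hypothetical 'original operator': multiplies every element by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


class ScaleASD(ToyASDBase):
    def __init__(self, *args, **kwargs):
        # Same shape as LayerNormASD: pass the operator class plus its arguments up.
        super().__init__(Scale, *args, **kwargs)

    def __call__(self, values):
        if self.enable_check:
            values = self.check_op(values)
        # Return the result of the original operator.
        return self.op(values)


op = ScaleASD(factor=2.0)
print(op([3.0, 4.0]))  # [6.0, 8.0]
print(op.history)      # [5.0]
```

Keeping the check in the base class and the call order in the derived class means each wrapped operator only decides which input to check and how to invoke `self.op`.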
generate_params()

Generates the supporting parameters for accuracy-sensitive detection.

Returns

tuple, consisting of four elements (unpacked in the examples as pre_val, min_val, max_val and cnt). A derived class calls this method to initialize the parameters required for accuracy-sensitive detection.

Examples

>>> from mindspore.ops.silent_check import ASDBase
>>> from mindspore.ops import LayerNorm as OriginLayerNorm
>>> class LayerNormASD(ASDBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(OriginLayerNorm, *args, **kwargs)
...         # initialize the accuracy-sensitive detection parameters by calling the base class method
...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()