mindspore_gs.pruner.PrunerFtCompressAlgo

class mindspore_gs.pruner.PrunerFtCompressAlgo(config=None)[source]

PrunerFtCompressAlgo is a subclass of CompAlgo that removes redundant convolution kernels from the network and then fully trains the pruned network.

Parameters

config (dict) – Configuration of PrunerFtCompressAlgo; keys are attribute names and values are attribute values. Supported attributes are listed below:

  • prune_rate (float): the proportion of the network to be pruned, a number in [0.0, 1.0).

Raises
  • TypeError – If prune_rate is not float.

  • ValueError – If prune_rate is less than 0.0 or greater than or equal to 1.0.

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore_gs.pruner import PrunerKfCompressAlgo, PrunerFtCompressAlgo
>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> class NetToPrune(nn.Cell):
...     def __init__(self):
...         super(NetToPrune, self).__init__()
...         self.layer = Net()
...
...     def construct(self, x):
...         x = self.layer(x)
...         return x
...
>>> net = NetToPrune()
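>>> ## Apply the knockoff (Kf) algorithm first to obtain a knockoff network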
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> net_pruning_kf = kf_pruning.apply(net)
>>> ## 1) Define FineTune Algorithm
>>> ft_pruning = PrunerFtCompressAlgo({'prune_rate': 0.5})
>>> ## 2) Apply FineTune-algorithm to origin network
>>> net_pruning_ft = ft_pruning.apply(net_pruning_kf)
>>> ## 3) Print network and check the result. Conv2d and BatchNorm2d should be transformed into MaskedConv2dbn.
>>> print(net_pruning_ft)
NetToPrune<
 (layer): Net<
  (conv): MaskedConv2dbn<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
      padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
      (name=conv.bn.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
      (name=conv.bn.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
      (name=conv.bn.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
      (name=conv.bn.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (bn): SequentialCell<>
  >
>
apply(network)[source]

Transform a knockoff network into a normal, pruned network.

Parameters

network (Cell) – Knockoff network.

Returns

Pruned network.

Raises

TypeError – If network is not a Cell.
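
A minimal usage sketch, assuming net_pruning_kf is the knockoff network produced by PrunerKfCompressAlgo.apply as in the class example above; the prune_rate value of 0.5 is illustrative.

>>> # Assumes `net_pruning_kf` was produced by PrunerKfCompressAlgo, as in the class example above.
>>> ft_algo = PrunerFtCompressAlgo({'prune_rate': 0.5})
>>> net_pruned = ft_algo.apply(net_pruning_kf)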

set_prune_rate(prune_rate: float)[source]

Set the value of prune_rate in _config.

Parameters

prune_rate (float) – the proportion of the network to be pruned.

Raises
  • TypeError – If prune_rate is not float.

  • ValueError – If prune_rate is less than 0.0 or greater than or equal to 1.0.
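
A minimal sketch, assuming the algorithm accepts its default configuration (config=None) at construction; 0.5 is an illustrative value and should be equivalent to passing {'prune_rate': 0.5} as config.

>>> from mindspore_gs.pruner import PrunerFtCompressAlgo
>>> ft_algo = PrunerFtCompressAlgo()   # assumes the default config=None is accepted
>>> ft_algo.set_prune_rate(0.5)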