mindspore.lazy_inline

查看源文件
mindspore.lazy_inline(fn=None, attrs=None)[源代码]

指定一个cell是可复用的。该cell在前端被编译为可复用的子图,后端再根据策略将其内联。将此装饰器注册到cell的内置函数 __init__ 上时,此装饰器会根据 attrs 的值,把 __init__ 函数中对应的入参添加为cell的属性。

警告

该特性仅支持Ascend,不支持其他硬件。

参数:
  • fn (function) - cell的 __init__ 函数。

  • attrs (Union[list[string], string]) - 需要添加为cell属性的 __init__ 入参名称,可以是单个名称(string)或名称列表(list[string])。默认值: None。

返回:

function,原始函数。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.nn as nn
>>> from mindspore import lazy_inline
>>> from mindspore import context
>>> from mindspore import ops
>>> def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
...     return nn.Conv2d(in_channels, out_channels,
...                      kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode)
...
>>> def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
...     return nn.Conv2d(in_channels, out_channels,
...                      kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode)
...
>>> class Block(nn.Cell):
...     # Residual block: 1x1 conv -> 3x3 conv -> 1x1 conv (each followed by BN),
...     # added to a shortcut branch and passed through ReLU.
...     # NOTE(review): `expansion` is not referenced anywhere in this example.
...     expansion = 4
...
...     # @lazy_inline marks this cell as reusable: its graph is compiled once as a
...     # reusable subgraph in the front end and inlined by the back end per policy.
...     @lazy_inline
...     def __init__(self,
...                  in_channels,
...                  out_channels,
...                  stride=1,
...                  down_sample=False):
...         super(Block, self).__init__()
...
...         # The inner width equals out_channels here (`expansion` is not applied).
...         out_chls = out_channels
...         self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
...         self.bn1 = nn.BatchNorm2d(out_chls)
...
...         # Only the 3x3 convolution uses `stride`, so spatial downsampling happens here.
...         self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
...         self.bn2 = nn.BatchNorm2d(out_chls)
...
...         self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
...         self.bn3 = nn.BatchNorm2d(out_channels)
...
...         self.relu = nn.ReLU()
...         # Selects in construct() whether the shortcut goes through the projection.
...         self.downsample = down_sample
...
...         # The projection shortcut is created even when down_sample is False.
...         # NOTE(review): presumably so every Block instance has the same structure
...         # and can share one reusable subgraph - confirm.
...         self.conv_down_sample = conv1x1(in_channels, out_channels,
...                                         stride=stride, padding=0)
...         self.bn_down_sample = nn.BatchNorm2d(out_channels)
...         self.add = ops.Add()
...
...     def construct(self, x):
...         # Keep the block input for the residual (shortcut) connection.
...         identity = x
...
...         out = self.conv1(x)
...         out = self.bn1(out)
...         out = self.relu(out)
...
...         out = self.conv2(out)
...         out = self.bn2(out)
...         out = self.relu(out)
...
...         # No ReLU here: it is applied after the residual addition below.
...         out = self.conv3(out)
...         out = self.bn3(out)
...
...         # Project the shortcut when channels/stride change, so shapes match for Add.
...         if self.downsample:
...             identity = self.conv_down_sample(identity)
...             identity = self.bn_down_sample(identity)
...
...         out = self.add(out, identity)
...         out = self.relu(out)
...
...         return out
...
>>> class Net(nn.Cell):
...     # Example network: 7x7 stem conv + BN + ReLU + max-pool, one stage of
...     # `layer_num` Block cells, then average pooling and flatten.
...     def __init__(self, block, num_classes=100):
...         super(Net, self).__init__()
...         # NOTE(review): num_classes is accepted but not used in this example.
...
...         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
...         self.bn1 = nn.BatchNorm2d(64)
...         self.relu = nn.ReLU()
...         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
...
...         # 50 instances of the same @lazy_inline Block: this repetition is what the
...         # reusable subgraph is meant to exploit.
...         self.layer = self.MakeLayer(
...             block, 50, in_channels=64, out_channels=2048, stride=2)
...         self.avgpool = nn.AvgPool2d(7, 1)
...         self.flatten = ops.Flatten()
...
...     def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
...         # Build `layer_num` blocks wrapped in an nn.SequentialCell.
...         layers = []
...         # The first block changes channels and applies `stride`, so it enables the
...         # projection shortcut (down_sample=True).
...         resblk = block(in_channels, out_channels,
...                        stride=stride, down_sample=True)
...         layers.append(resblk)
...
...         # Remaining blocks keep out_channels, stride 1 and the default down_sample=False.
...         for _ in range(1, layer_num):
...             resblk = block(out_channels, out_channels, stride=1)
...             layers.append(resblk)
...
...         return nn.SequentialCell(layers)
...
...     def construct(self, x):
...         x = self.conv1(x)
...         x = self.bn1(x)
...         x = self.relu(x)
...         x = self.maxpool(x)
...         x = self.layer(x)
...         x = self.avgpool(x)
...         x = self.flatten(x)
...         return x
...
>>> def test_compile():
...     net = Net(Block)
...     inp = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
...     net(inp)
...
>>> context.set_context(mode=context.GRAPH_MODE,
...                     save_graphs=True, save_graphs_path="./lazy")
...
>>> test_compile()