调用自定义类

在线运行下载Notebook下载样例代码查看源文件

概述

通过ms_class,用户可以在静态图模式下调用自定义类的属性和方法。

在静态图模式下,用户需要获取自定义类的属性/方法时,可以对该类使用@ms_class装饰器,从而调用其属性/方法。在动态图模式即PyNative模式下,ms_class的使用也是支持的,但用户不需要@ms_class装饰器也能调用自定义类的属性和方法。

本文档主要介绍ms_class的使用场景和使用须知,以便您可以更有效地使用ms_class功能。

使用场景

1、调用自定义类的属性

调用自定义类的属性时,可以通过@ms_class装饰器,对自定义类进行修饰。

[1]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class

@ms_class
class InnerNet:
    def __init__(self):
        self.value = Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.value
        return out

context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]

2、调用自定义类的方法

调用自定义类的方法时,可以通过@ms_class装饰器,对自定义类进行修饰。

[2]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class

@ms_class
class InnerNet:
    def act(self, x, y):
        return x + y

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self, x, y):
        out = self.inner_net.act(x, y)
        return out

context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([4, 5, 6]).astype(np.int32))
net = Net()
out = net(x, y)
print(out)
[5 7 9]

3、调用嵌套的自定义类的属性和方法

多个自定义类嵌套时,如果都使用了@ms_class装饰器,则可以获取嵌套类的属性和方法。

[3]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class

@ms_class
class Inner:
    def __init__(self):
        self.value = Tensor(np.array([1, 2, 3]))

@ms_class
class InnerNet:
    def __init__(self):
        self.inner = Inner()

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.inner.value
        return out

context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]

4、自定义类和nn.Cell嵌套使用

当自定义类和nn.Cell嵌套使用时,调用自定义类的属性和方法。关于nn.Cell的介绍,请参考mindspore.nn.Cell

[4]:
import numpy as np
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore import context, Tensor, ms_class

class Net(nn.Cell):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def construct(self, x):
        return x + self.val

@ms_class
class TrainNet():
    class Loss(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, x):
            out = self.net(x)
            return out * 2

    def __init__(self, net):
        self.net = net
        loss_net = self.Loss(self.net)
        self.number = loss_net(10)

global_net = Net(1)
class LearnNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.value = TrainNet(global_net).number

    def construct(self, x):
        return x + self.value


context.set_context(mode=context.GRAPH_MODE)
x = Tensor(3, mstype.int32)
leanrn_net = LearnNet()
out = leanrn_net(x)
print(out)
25

使用须知

使用ms_class时,需要考虑以下条件:

1、ms_class不支持非class类型

from mindspore import ms_class

@ms_class
def func(x, y):
    return x + y

func(1, 2)

执行代码后,将会提示以下报错信息:

TypeError: Decorator ms_class can only be used for class type, but got <function func at 0x7fee33c005f0>.

2、ms_class支持调用类实例的属性和方法,不支持直接从类定义获取其属性和方法,不支持在construct/ms_function函数中创建自定义类的实例。

import mindspore.nn as nn
from mindspore import context, ms_class

@ms_class
class InnerNet:
    def __init__(self):
        self.number = 1

class Net(nn.Cell):
    def construct(self):
        out = InnerNet().number
        return out

context.set_context(mode=context.GRAPH_MODE)
net = Net()
net()

执行代码后,将会提示以下报错信息:

ValueError: This may be not defined, or it can’t be a operator. Please check code.

3、不支持对nn.Cell使用@ms_class装饰器。

import mindspore.nn as nn
from mindspore import context, Tensor, ms_class

@ms_class
class Net(nn.Cell):
    def construct(self, x):
        return x

context.set_context(mode=context.GRAPH_MODE)
x = Tensor(1)
net = Net()
net(x)

执行代码后,将会提示以下报错信息:

TypeError: ms_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.

4、不支持调用自定义类的私有属性或魔术方法。

import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class

@ms_class
class InnerNet:
    def __init__(self):
        self.value = Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.__str__()
        return out

context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()

执行代码后,将会提示以下报错信息:

AttributeError: __str__ is a private variable or magic method, which is not supported.