Calling the Custom Class

Overview

In static graph mode, decorating a custom class with jit_class lets you create and call instances of that class and access its attributes and methods.

jit_class applies to static graph mode and extends the syntax that static graph compilation supports. In dynamic graph mode (PyNative mode), jit_class does not affect the execution logic.

This document describes how to use the jit_class decorator effectively.

jit_class Decorates Custom Class

After decorating a custom class with @jit_class, you can create and call instances of the class and access its attributes and methods.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    value = ms.Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def construct(self):
        return InnerNet().value

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]

jit_class supports nested use of custom classes, including a custom class nested inside another custom class and a custom class nested inside an nn.Cell. Note that with inheritance, if the parent class is decorated with jit_class, the subclass also gains the jit_class capability.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class Inner:
    def __init__(self):
        self.value = ms.Tensor(np.array([1, 2, 3]))

@ms.jit_class
class InnerNet:
    def __init__(self):
        self.inner = Inner()

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.inner.value
        return out

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]
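
The inheritance note above can be pictured with a plain-Python sketch. It assumes that a class decorator works by setting a marker attribute on the class; `mark_class`, `_is_marked`, and `is_marked` are hypothetical names used only for illustration and are not part of the MindSpore API. Because class attributes are inherited, a subclass of a decorated class carries the marker without being decorated itself.

```python
def mark_class(cls):
    """Hypothetical stand-in for a class decorator that tags a class."""
    cls._is_marked = True  # hypothetical marker attribute
    return cls


@mark_class
class Base:
    def __init__(self):
        self.value = [1, 2, 3]


class Derived(Base):  # not decorated itself
    pass


class Unrelated:
    pass


def is_marked(obj):
    """Return True if the object's class carries the marker, directly or by inheritance."""
    return getattr(type(obj), "_is_marked", False)


print(is_marked(Base()))       # True
print(is_marked(Derived()))    # True
print(is_marked(Unrelated()))  # False
```

Under this assumption, only the parent class needs the decorator; decorating the subclass again would be harmless but redundant.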

jit_class only supports decorating custom classes; it cannot decorate nn.Cell or non-class types. Executing the following use case raises an error.

import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class Net(nn.Cell):
    def construct(self, x):
        return x

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(1)
net = Net()
net(x)

The error information is as follows:

TypeError: Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.

Decorating a function with jit_class also raises an error:

import mindspore as ms

@ms.jit_class
def func(x, y):
    return x + y

func(1, 2)

The error information is as follows:

TypeError: Decorator jit_class can only be used for class type, but got <function func at 0x7fee33c005f0>.
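
The two checks behind these errors can be sketched in plain Python. This is a sketch under stated assumptions, not MindSpore's source code: `Cell` is a hypothetical stand-in for nn.Cell, and `checked_class_decorator` and `try_decorate` are hypothetical helper names. The decorator first rejects anything that is not a class, using `inspect.isclass`, and then rejects subclasses of the `Cell` stand-in.

```python
import inspect


class Cell:
    """Hypothetical stand-in for nn.Cell, used only in this sketch."""


def checked_class_decorator(cls):
    """Hypothetical decorator that accepts only plain user-defined classes."""
    if not inspect.isclass(cls):
        raise TypeError(f"decorator can only be used for class type, but got {cls!r}")
    if issubclass(cls, Cell):
        raise TypeError(f"decorator cannot be used for Cell subclasses: {cls.__name__}")
    return cls


@checked_class_decorator
class Plain:  # accepted: an ordinary class
    pass


class MyCell(Cell):
    pass


def func(x, y):
    return x + y


def try_decorate(obj):
    """Return the error message raised when decorating obj, or None on success."""
    try:
        checked_class_decorator(obj)
    except TypeError as err:
        return str(err)
    return None


print(try_decorate(Plain))   # None
print(try_decorate(MyCell))  # message about Cell subclasses
print(try_decorate(func))    # message about class type
```

Checking for a class before calling `issubclass` matters, because `issubclass` itself raises a TypeError when its first argument is not a class.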

Obtaining the Attributes and Methods of the Custom Class

Attributes of a class can be accessed through the class name, but methods cannot be called through the class name. For an instance of a class, both its attributes and its methods can be accessed.

import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self, val):
        self.number = val

    def act(self, x, y):
        return self.number * (x + y)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet(2)

    def construct(self, x, y):
        return self.inner_net.number + self.inner_net.act(x, y)

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)
12

Accessing private attributes and magic methods is not supported, and any method that is called must stay within the syntax supported by static graph compilation. Executing the following use case raises an error.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self):
        self.value = ms.Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.__str__()
        return out

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()

The error information is as follows:

RuntimeError: `__str__` is a private variable or magic method, which is not supported.

Creating Instance of the Custom Class

In static graph mode, when you create an instance of a custom class within graph-compiled code such as construct, the arguments must be constants.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self, val):
        self.number = val + 3

class Net(nn.Cell):
    def construct(self):
        net = InnerNet(2)
        return net.number

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
5

In other scenarios, there is no requirement that the arguments be constants when creating an instance of a custom class. For example, the following use case creates the instance in __init__ and passes it a Tensor:

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self, val):
        self.number = val + 3

class Net(nn.Cell):
    def __init__(self, val):
        super(Net, self).__init__()
        self.inner = InnerNet(val)

    def construct(self):
        return self.inner.number

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
net = Net(x)
out = net()
print(out)
5

Calling the Instance of the Custom Class

Calling an instance of a custom class invokes the __call__ method of that class.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self, number):
        self.number = number

    def __call__(self, x, y):
        return self.number * (x + y)

class Net(nn.Cell):
    def construct(self, x, y):
        net = InnerNet(2)
        out = net(x, y)
        return out

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)
10

If the class does not define a __call__ method, an error is reported. Executing the following use case raises an error.

import numpy as np
import mindspore.nn as nn
import mindspore as ms

@ms.jit_class
class InnerNet:
    def __init__(self, number):
        self.number = number

class Net(nn.Cell):
    def construct(self, x, y):
        net = InnerNet(2)
        out = net(x, y)
        return out

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)

The error information is as follows:

RuntimeError: MsClassObject: 'InnerNet' has no `__call__` function, please check the code.
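
For comparison, ordinary Python follows a similar rule outside graph compilation: calling an instance dispatches to the `__call__` method of its class, and an instance whose class lacks one raises a TypeError. The sketch below uses plain integers instead of tensors, and the class names `Scaler` and `NoCall` are illustrative only.

```python
class Scaler:
    def __init__(self, number):
        self.number = number

    def __call__(self, x, y):
        # Invoked when an instance is called like a function.
        return self.number * (x + y)


class NoCall:
    def __init__(self, number):
        self.number = number


scaler = Scaler(2)
result = scaler(2, 3)      # equivalent to Scaler.__call__(scaler, 2, 3)
print(result)              # 10

no_call = NoCall(2)
print(callable(scaler))    # True
print(callable(no_call))   # False

try:
    no_call(2, 3)
except TypeError as err:
    message = str(err)
    print(message)
```

The built-in `callable` offers a quick way to check whether a class defines `__call__` before using its instances as functions.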