mindspore.ops.MultitypeFuncGraph

class mindspore.ops.MultitypeFuncGraph(name, read_value=False)

MultitypeFuncGraph is a class for generating overloaded functions, where the function that runs is chosen according to the types of the inputs. Initialize a MultitypeFuncGraph object with a name, then decorate each function to be registered with the object's register method, passing the input types as its arguments. The resulting object can be called with inputs of different types, and it works with HyperMap and Map.
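To make the dispatch idea concrete, here is a minimal pure-Python sketch. It is a hypothetical stand-in, not MindSpore's implementation: the class `ToyMultitypeFuncGraph` and the `_TYPE_TABLE` that maps type-name strings such as "Number" to Python types are invented for illustration, and the built-in `map` only loosely mimics what Map does. The toy records functions under tuples of type names, runs the first entry whose types match the arguments, and raises `ValueError` when nothing matches.

```python
# Hypothetical pure-Python illustration of type-based overload dispatch.
# NOT MindSpore's implementation: the class and the type table are invented.

# Invented mapping from type-name strings to Python types.
_TYPE_TABLE = {
    "Number": (int, float),
    "String": (str,),
    "Tuple": (tuple,),
}


class ToyMultitypeFuncGraph:
    """Stores (type names, function) entries and dispatches on argument types."""

    def __init__(self, name):
        self.name = name
        self.entries = []  # list of (tuple of type names, function)

    def register(self, *type_names):
        # register() returns a decorator that records fn under type_names.
        def decorator(fn):
            self.entries.append((type_names, fn))
            return fn
        return decorator

    def __call__(self, *args):
        # Run the first registered function whose type names match the arguments.
        for type_names, fn in self.entries:
            if len(type_names) == len(args) and all(
                isinstance(arg, _TYPE_TABLE[t]) for arg, t in zip(args, type_names)
            ):
                return fn(*args)
        found = tuple(type(a).__name__ for a in args)
        raise ValueError(f"{self.name}: no registered function matches {found}")


add = ToyMultitypeFuncGraph("add")


@add.register("Number", "Number")
def add_scalar(x, y):
    return x + y


@add.register("String", "String")
def add_string(x, y):
    return x + " " + y


print(add(1, 2))      # prints 3
print(add("a", "b"))  # prints a b
# Element-wise application with the built-in map, loosely in the spirit of Map:
print(list(map(add, (1, "x"), (2, "y"))))  # prints [3, 'x y']

try:
    add(1, "b")  # (int, str) matches no registered entry
except ValueError as err:
    print(err)
```

The real class matches on MindSpore types such as "Tensor" and integrates with graph compilation; the toy only shows the register-then-dispatch shape.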

Parameters
  • name (str) – Operator name.

  • read_value (bool, optional) – Set read_value to True if the registered functions do not need to set values on a Parameter and all inputs will be passed by value. Default: False.

Raises

ValueError – If no registered function matches the given arguments.

Supported Platforms:

Ascend GPU CPU

Examples

>>> # `add` is a MultitypeFuncGraph object that adds two objects, using the
>>> # function registered via the ".register" decorator for the input types.
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import dtype as mstype
>>>
>>> tensor_add = ops.Add()
>>> add = ops.MultitypeFuncGraph('add')
>>> @add.register("Number", "Number")
... def add_scalar(x, y):
...     return x + y
>>> @add.register("Tensor", "Tensor")
... def add_tensor(x, y):
...     return tensor_add(x, y)
>>> output = add(1, 2)
>>> print(output)
3
>>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
>>> print(output)
[0.2 1.2 2.4]
register(*type_names)

Register a function for the given type string.

Parameters

type_names (Union[str, mindspore.dtype]) – The type names (or types) of the inputs, one per argument of the function being registered.

Returns

decorator, a decorator that registers the decorated function so that it runs when the object is called with inputs of the types described in type_names.
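Because register returns a decorator, the @ syntax is shorthand for calling that decorator on the function explicitly. The sketch below illustrates this with an invented `ToyRegistry` class (not part of MindSpore) that accepts only string type names, unlike the real method, which also accepts mindspore.dtype values. It also assumes the decorator hands back the original function unchanged.

```python
# Hypothetical sketch of the decorator-factory pattern: register(*type_names)
# returns a decorator, and the decorator records the function it receives.
# ToyRegistry is invented for illustration and is not part of MindSpore.


class ToyRegistry:
    def __init__(self):
        self.entries = {}  # maps a tuple of type names to a function

    def register(self, *type_names):
        def decorator(fn):
            self.entries[type_names] = fn
            return fn  # assumption: hand back fn so its name stays usable
        return decorator


reg = ToyRegistry()


# Decorator syntax ...
@reg.register("Number", "Number")
def add_scalar(x, y):
    return x + y


# ... is shorthand for calling the returned decorator explicitly.
def concat(x, y):
    return x + y


deco = reg.register("String", "String")  # deco is a plain function
concat = deco(concat)

print(sorted(reg.entries))  # prints [('Number', 'Number'), ('String', 'String')]
print(reg.entries[("String", "String")]("ab", "cd"))  # prints abcd
```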