mindspore.nn.GraphCell

查看源文件
class mindspore.nn.GraphCell(graph, params_init=None, obf_random_seed=None)[源代码]

运行从MindIR加载的计算图。

此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。

参数:
  • graph (FuncGraph) - 从MindIR加载的编译图。

  • params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: None

  • obf_random_seed (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 mindspore.obfuscate_model() 。如果导入的 graph 是一个经过混淆的模型,那么须提供 obf_random_seedobf_random_seed 的取值范围是(0, 9223372036854775807]。默认值: None

异常:
  • TypeError - 如果图不是FuncGraph类型。

  • TypeError - 如果 params_init 不是字典。

  • TypeError - 如果 params_init 的key不是字符串。

  • TypeError - 如果 params_init 的value既不是 Tensor也不是Parameter。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore import context
>>> context.set_context(mode=context.GRAPH_MODE)
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> ms.export(net, input, file_name="net", file_format="MINDIR")
>>> graph = ms.load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
>>> print(output)
[[[[4. 6. 4.]
   [6. 9. 6.]
   [4. 6. 4.]]]]