# Source code for mindspore_gl.parser.vcg

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Translation."""
import ast
import inspect
from types import MethodType
from textwrap import dedent
from typing import Optional

from ast_decompiler import decompile

from .infer_expr_type_pass import InferExprTypePass
from .check_syntax_pass import CheckSyntaxPass
from .ast_rewriter import AstRewriter
from .code_comparator import CodeComparator
from .utils import src_to_function

# Screen width handed to the code comparator that `translate` creates.
SCREEN_WIDTH = 200
# When true, `translate` shows a comparison of the original and translated code.
DISPLAY = True


def set_display_config(screen_width, display):
    """
    Configure how the `translate` function displays its code comparison.

    Both values are stored in module-level globals and are read by
    `translate` every time it runs.

    Args:
        screen_width (int): Screen width used when displaying the code.
        display (bool): Whether the code comparison is shown.
    """
    global SCREEN_WIDTH, DISPLAY
    SCREEN_WIDTH, DISPLAY = screen_width, display


def translate(obj, method_name: str, translate_path: Optional[str] = None):
    """
    Translate the vertex central code into MindSpore understandable code.

    After translation, a new function will be generated in `/.mindspore_gl`.
    The original method will be replaced with this function.

    Args:
        obj (Object): The object.
        method_name (str): The name of the method to be translated.
        translate_path (str): The path used to save the generated construct file. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> from mindspore_gl.nn import GNNCell
        >>> from mindspore_gl import BatchedGraph
        >>> from mindspore_gl.parser.vcg import translate
        ...
        >>> class Net(GNNCell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         translate(self, "loss")
        ...
        ...     def construct(self, pred, label, g: BatchedGraph):
        ...         loss = self.loss(pred, label, g)
        ...         loss = loss * g.graph_mask
        ...         return loss
        ...
        ...     def loss(self, pred, label, g: BatchedGraph):
        ...         criterion = ms.nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
        ...         loss = criterion(pred, label)
        ...         loss = ops.ReduceMean()(loss * g.graph_mask)
        ...         return loss
    """
    # Read the module-level display switch once. Otherwise a concurrent
    # set_display_config() call could flip it between the two checks below
    # and leave `comparator` unbound.
    display = DISPLAY

    fn = getattr(obj, method_name)
    # Dedent so a method defined inside a class body parses as top-level code.
    src = dedent(inspect.getsource(fn))
    py_ast = ast.parse(src)

    # Analysis passes: syntax check against the method's globals, then
    # expression-type inference over the same AST.
    syntax_checker = CheckSyntaxPass(fn.__globals__)
    ret = syntax_checker.analyze(py_ast)
    type_inferer = InferExprTypePass(ret, src)
    ret = type_inferer.analyze(py_ast)

    comparator = None
    if display:
        comparator = CodeComparator(SCREEN_WIDTH)
        # Record the original line numbers before the AST is rewritten.
        comparator.record_origin_lineno(py_ast)

    rewriter = AstRewriter(ret)
    new_ast = rewriter.visit(py_ast)

    if comparator is not None:
        comparator.mapping_by_origin_lineno(new_ast)
        comparator.show_diff()

    # Turn the rewritten AST back into source and build a function from it
    # in the original method's global namespace.
    new_src = decompile(new_ast)
    new_fn = src_to_function(new_src, method_name, fn.__globals__, translate_path)
    new_fn.__module__ = fn.__module__
    # Bind the generated function to `obj` and install it under the original
    # method name, replacing the method that was translated.
    setattr(obj, method_name, MethodType(new_fn, obj))