Source code for mindspore.profiler.profiling

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Profiling api file."""
import os
import stat
import time
import json

from mindspore import log as logger, context
from mindspore.communication.management import GlobalComm, get_rank, get_group_size
import mindspore._c_expression as c_expression
import mindspore._c_dataengine as cde
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \
    ProfilerIOException, ProfilerException, ProfilerRawFileException
from mindspore.profiler.common.exceptions.exceptions import ProfilerPathErrorException
from mindspore.profiler.common.exceptions.exceptions import ProfilerDirNotFoundException
from mindspore.profiler.common.util import get_file_path, fwrite_format
from mindspore.profiler.common.validator.validate_path import \
    validate_and_normalize_path
from mindspore.profiler.parser.aicpu_data_parser import DataPreProcessParser
from mindspore.profiler.parser.framework_parser import FrameworkParser, GpuFrameWorkParser, DynamicFrameWorkParser
from mindspore.profiler.parser.hwts_log_parser import HWTSLogParser
from mindspore.profiler.parser.integrator import Integrator, DeviceTarget
from mindspore.profiler.parser.integrator import GpuTimelineGenerator, CpuTimelineGenerator, AscendTimelineGenerator
from mindspore.profiler.parser.memory_usage_parser import MemoryUsageParser
from mindspore.profiler.parser.minddata_parser import MinddataParser
from mindspore.profiler.parser.minddata_analyzer import MinddataProfilingAnalyzer
from mindspore.profiler.parser.flops_parser import FlopsParser
from mindspore.profiler.parser.minddata_pipeline_parser import MinddataPipelineParser
from mindspore.profiler.parser.optime_parser import OPComputeTimeParser
from mindspore.profiler.parser.step_trace_parser import GpuStepTraceParser, AscendStepTraceParser
from mindspore.profiler.parser.hccl_parser import HcclParser
from mindspore.profiler.parser.op_intermediate_parser import OPIntermediateParser

# Full name of the dataset-queue initialisation operator. Ascend step-trace analysis checks the
# framework data for this op name and passes the result on as the "skip first step" flag.
INIT_OP_NAME = 'Default/InitDataSetQueue'


def _environment_check():
    """
    Make sure profiling is usable with the current MindSpore build.

    Raises:
        RuntimeError: If ``c_expression.security.enable_security()`` reports that MindSpore was
            compiled in security mode ('-s on'), in which the profiler is not supported.
    """
    if not c_expression.security.enable_security():
        return
    raise RuntimeError("Profiler is not supported when MindSpore is compiled with '-s on'.")


[文档]class Profiler: """ This class to enable the profiling of MindSpore neural networks. MindSpore users can import the mindspore.Profiler, initialize the Profiler object to start profiling, and use Profiler.analyse() to stop profiling and analyse the results. Users can visualize the results using the MindInsight tool. Now, Profiler supports AICORE operator, AICPU operator, HostCPU operator, memory, correspondence, cluster, etc data analysis. Args: output_path (str, optional): Output data path. Default: "./data". profile_communication (bool, optional): (Ascend only) Whether to collect communication performance data in a multi devices training,collect when True. Setting this parameter has no effect during single device training. Default: False. profile_memory (bool, optional): (Ascend only) Whether to collect tensor memory data, collect when True. Default: False. start_profile (bool, optional): The start_profile parameter controls whether to enable or disable performance data collection based on conditions. Default: True. Raises: RuntimeError: When the version of CANN does not match the version of MindSpore, MindSpore cannot parse the generated ascend_job_id directory structure. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> import numpy as np >>> import mindspore as ms >>> from mindspore import nn >>> import mindspore.dataset as ds >>> from mindspore import Profiler >>> >>> >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.fc = nn.Dense(2,2) ... def construct(self, x): ... return self.fc(x) >>> >>> def generator(): ... for i in range(2): ... yield (np.ones([2, 2]).astype(np.float32), np.ones([2]).astype(np.int32)) >>> >>> def train(net): ... optimizer = nn.Momentum(net.trainable_params(), 1, 0.9) ... loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) ... data = ds.GeneratorDataset(generator, ["data", "label"]) ... model = ms.Model(net, loss, optimizer) ... model.train(1, data) >>> >>> if __name__ == '__main__': ... 
# If the device_target is GPU, set the device_target to "GPU" ... ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") ... ... # Init Profiler ... # Note that the Profiler should be initialized before model.train ... profiler = Profiler() ... ... # Train Model ... net = Net() ... train(net) ... ... # Profiler end ... profiler.analyse() """ _hwts_output_filename_target = "output_format_data_hwts_" _opcompute_output_filename_target = "output_op_compute_time_" _aicpu_op_output_filename_target = "output_data_preprocess_aicpu_" _has_analysed = False _has_initialized = False _ascend_profiling_options = "" _ascend_job_id = "" def __init__(self, **kwargs): if Profiler._has_initialized: msg = "Do not init twice in the profiler." raise RuntimeError(msg) Profiler._has_initialized = True self._dev_id = None self._cpu_profiler = None self._gpu_profiler = None self._md_profiler = None self._init_time = None self._ascend_job_id = '' self._job_id_env = None self._filt_optype_names = '' self._output_path = '' self._rank_size = 0 self._ascend_profiler = None _environment_check() # get device_id and device_target self._get_devid_rankid_and_devtarget() self._get_output_path(kwargs) self._profile_communication = False self._has_started = False self._has_started_twice = False self.start_profile = True self._profile_memory = False self._stop_time = 0 self._ascend_dynamic_status = False self._cpu_dynamic_status = False self._gpu_dynamic_status = False self._decide_device_target(kwargs) if self.start_profile: self.start() @staticmethod def _parse_host_start_log(input_file): """ Parse host start log file, get the start time of the job. Args: input_file (str): The file path of the host start log file. Returns: str, job start time. """ job_start_time = "" with open(input_file) as f: for line in f.readlines(): if "clock_realtime" in line: # 16 means the first digit of the timestamp, len(line)-3 means the last. 
job_start_time = line[16:len(line) - 3] return job_start_time @staticmethod def _check_output_path(output_path): """Checking path validity.""" try: output_path = validate_and_normalize_path(output_path) except RuntimeError as e: raise ProfilerPathErrorException(f'profiling data output path {output_path} is invalid.') from e finally: pass if not os.path.isdir(output_path): raise ProfilerDirNotFoundException(output_path) return output_path
def op_analyse(self, op_name, device_id=None):
    """
    Query operator performance data collected by the profiler.

    Args:
        op_name (str or list): The primitive operator name, or list of names, to query.
        device_id (int, optional): ID of the device whose operator data should be parsed. When omitted,
            the device of the current process is used, or 0 when that is unknown (e.g. when the data is
            parsed offline). Default: None.

    Returns:
        The operator information returned by ``GpuFrameWorkParser.parse()``. In a multi-device job
        (rank size > 1) where ``device_id`` differs from the current device, a str message explaining
        that no operator performance information was queried is returned instead.

    Raises:
        TypeError: If ``op_name`` is not a str or list, or is empty.
        TypeError: If ``device_id`` is given and is not an int.
        RuntimeError: If MindSpore runs on Ascend; this interface is not supported there.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Profiler
        >>> # Profiler init.
        >>> profiler = Profiler()
        >>> # Train Model or eval Model.
        >>> net = Net()
        >>> train(net)
        >>> # Profiler end
        >>> profiler.analyse()
        >>> profiler.op_analyse(op_name=["BiasAdd", "Conv2D"])

        >>> from mindspore import Profiler
        >>> # Parse data that was previously written to output_path.
        >>> profiler = Profiler(output_path="my_profiler_path")
        >>> profiler.op_analyse(op_name="Conv2D")
    """
    if self._device_target == 'ascend':
        raise RuntimeError("The Interface 'Profiler.op_analyse()' is not supported on Ascend currently.")
    # Explicit None check: falsy non-int values such as "" must be rejected as well.
    if device_id is not None and not isinstance(device_id, int):
        raise TypeError(f"For 'Profiler.op_analyse()', the parameter device_id must be int, "
                        f"but got type {type(device_id)}")
    # Device of the current process; an unknown device id falls back to 0 instead of failing on int(None).
    online_device_id = 0 if self._dev_id is None else int(self._dev_id)
    # Keep the queried device local so this query does not change self._dev_id for later calls.
    target_device_id = self._dev_id if device_id is None else device_id
    if target_device_id is None:
        target_device_id = 0
    if not isinstance(op_name, (str, list)):
        raise TypeError(f"For 'Profiler.op_analyse()', the parameter op_name must be str or list, "
                        f"but got type {type(op_name)}")
    if not op_name:
        raise TypeError("For 'Profiler.op_analyse()', the parameter op_name cannot be \"\", '' or [].")
    parser = GpuFrameWorkParser(self._output_path, target_device_id, op_name)
    op_info = parser.parse()
    # In a multi-device job only the current device's data is reported.
    if self._rank_size > 1 and online_device_id != int(target_device_id):
        return (f"For 'Profiler.op_analyse()', the parameter device_id is equal to {target_device_id}, but the "
                f"current device id is {online_device_id}, so no operator performance information is queried.")
    return op_info
def start(self):
    """
    Start profiling. Used for Ascend and GPU; profiling can be turned on based on step and epoch.

    Raises:
        RuntimeError: If ``start_profile`` is False in PyNative mode (conditional collection is not
            supported there).
        RuntimeError: If the profiler has already started, e.g. because it was created with
            ``start_profile=True``.
        RuntimeError: If profiling was already started and stopped once; repeated start and stop
            actions are not supported.

    Examples:
        >>> class StopAtStep(Callback):
        ...     def __init__(self, start_step, stop_step):
        ...         super(StopAtStep, self).__init__()
        ...         self.start_step = start_step
        ...         self.stop_step = stop_step
        ...         self.profiler = Profiler(start_profile=False)
        ...
        ...     def step_begin(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         step_num = cb_params.cur_step_num
        ...         if step_num == self.start_step:
        ...             self.profiler.start()
        ...
        ...     def step_end(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         step_num = cb_params.cur_step_num
        ...         if step_num == self.stop_step:
        ...             self.profiler.stop()
        ...
        ...     def end(self, run_context):
        ...         self.profiler.analyse()
    """
    if not self.start_profile and context.get_context("mode") == context.PYNATIVE_MODE:
        raise RuntimeError("Pynative model does not support conditional collection of performance data.")
    if self._has_started:
        raise RuntimeError("The profiler has already started. Use profiler.start() only when start_profile value "
                           "is set to False.")
    # _has_started_twice stays True once profiling has been started, so a second start/stop cycle is refused.
    if self._has_started_twice:
        raise RuntimeError("MindSpore Profiling has finished, repeated start and stop actions are not "
                           "supported.")
    self._has_started = True
    self._has_started_twice = True

    # Record the start time only after the state checks, so a rejected call keeps the previous value.
    self._start_time = int(time.time() * 10000000)
    logger.info("Profiling: start time: %d", self._start_time)

    # No need to start anything if parse profiling data offline
    if self._is_offline_parser():
        return

    self._cpu_profiler.step_profiling_enable(True)

    if self._device_target and self._device_target == DeviceTarget.GPU.value:
        self._md_profiler.start()
        self._gpu_profiler.step_profiling_enable(True)
    elif self._device_target and self._device_target == DeviceTarget.ASCEND.value:
        self._md_profiler.start()
        if context.get_context("mode") == context.PYNATIVE_MODE:
            self._ascend_pynative_start()
        else:
            self._ascend_graph_start()
def stop(self):
    """
    Stop profiling. Used for Ascend and GPU; profiling can be turned off based on step and epoch.

    The stop time is recorded (and logged) after the device profilers have been stopped.

    Raises:
        RuntimeError: If the profiler has not started.

    Examples:
        >>> class StopAtEpoch(Callback):
        ...     def __init__(self, start_epoch, stop_epoch):
        ...         super(StopAtEpoch, self).__init__()
        ...         self.start_epoch = start_epoch
        ...         self.stop_epoch = stop_epoch
        ...         self.profiler = Profiler(start_profile=False)
        ...
        ...     def epoch_begin(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         epoch_num = cb_params.cur_epoch_num
        ...         if epoch_num == self.start_epoch:
        ...             self.profiler.start()
        ...
        ...     def epoch_end(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         epoch_num = cb_params.cur_epoch_num
        ...         if epoch_num == self.stop_epoch:
        ...             self.profiler.stop()
        ...
        ...     def end(self, run_context):
        ...         self.profiler.analyse()
    """
    if not self._has_started:
        raise RuntimeError("The profiler has not started, so can not stop. Please call the start() method "
                           "before calling the stop() method.")
    self._has_started = False

    # No need to stop anything if parse profiling data offline
    if self._is_offline_parser():
        return

    # The MindData profiler is only created for GPU and Ascend targets; on CPU it is still None.
    if self._md_profiler is not None:
        self._md_profiler.stop()
        self._md_profiler.save(self._output_path)

    if self._device_target and self._device_target == DeviceTarget.GPU.value:
        self._gpu_profiler.stop()
    elif self._device_target and self._device_target == DeviceTarget.ASCEND.value:
        if context.get_context("mode") == context.PYNATIVE_MODE:
            self._pynative_profiler.stop()
        self._ascend_profiler.stop()

    self._stop_time = int(time.time() * 10000000)
    logger.info("Profiling: stop time: %d", self._stop_time)
def analyse(self):
    """
    Analyse the collected performance data; may be called during or after training.

    Resets the class-level initialisation flag so another Profiler can be created afterwards, stops the
    CPU profiler, and then runs the analysis routine that matches the device target (CPU, GPU or
    Ascend). See the class docstring for a complete usage example.
    """
    Profiler._has_initialized = False
    self._cpu_dynamic_status = self._cpu_profiler.dynamic_status()
    _environment_check()

    self._cpu_profiler.stop()

    device_target = self._device_target
    if device_target:
        analyser_by_target = {
            DeviceTarget.CPU.value: self._cpu_analyse,
            DeviceTarget.GPU.value: self._gpu_analyse,
            DeviceTarget.ASCEND.value: self._ascend_analyse,
        }
        run_analysis = analyser_by_target.get(device_target)
        if run_analysis is not None:
            run_analysis()
    logger.info("Profiling: all the data have been analyzed.")
def _decide_device_target(self, kwargs):
    """
    Complete Profiler initialization according to the device target.

    When a device target is set, the CPU profiler instance is always initialised first; then the
    target-specific initialiser (CPU, GPU or Ascend) is called with ``kwargs``.

    Args:
        kwargs (dict): Keyword arguments passed to the Profiler constructor.
    """
    if self._device_target:
        cpu_profiler = c_expression.CPUProfiler
        self._cpu_profiler = cpu_profiler.get_instance()
        self._cpu_profiler.init(self._output_path)

    if self._device_target and self._device_target == DeviceTarget.CPU.value:
        self._cpu_profiler_init(kwargs)
    elif self._device_target and self._device_target == DeviceTarget.GPU.value:
        self._gpu_profiler_init(kwargs)
    elif self._device_target and self._device_target == DeviceTarget.ASCEND.value:
        self._ascend_profiler_init(kwargs)

def _cpu_profiler_init(self, kwargs):
    """
    Initialise profiling on CPU.

    Args:
        kwargs (dict): Constructor keyword arguments; ``start_profile`` is popped from it.

    Raises:
        RuntimeError: If running in PyNative mode, which is not supported on CPU.
        TypeError: If ``start_profile`` is not a bool.
    """
    if context.get_context("mode") == context.PYNATIVE_MODE:
        raise RuntimeError("Pynative model is not supported on CPU currently.")

    self.start_profile = kwargs.pop("start_profile", True)
    if not isinstance(self.start_profile, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter start_profile must be bool, "
                        f"but got type {type(self.start_profile)}")

def _gpu_profiler_init(self, kwargs):
    """
    Initialise MindData and GPU profiling.

    Args:
        kwargs (dict): Constructor keyword arguments, parsed by ``_parse_parameter_for_gpu``.

    Raises:
        RuntimeError: If running in PyNative mode, which is not supported on GPU.
    """
    # Reject PyNative mode before the MindData profiling manager is initialised.
    if context.get_context("mode") == context.PYNATIVE_MODE:
        raise RuntimeError("Pynative model is not supported on GPU currently.")
    # Setup MindData Profiling
    self._md_profiler = cde.GlobalContext.profiling_manager()
    self._md_profiler.init()
    self._parse_parameter_for_gpu(kwargs)

    gpu_profiler = c_expression.GPUProfiler
    self._gpu_profiler = gpu_profiler.get_instance()
    self._gpu_profiler.init(self._output_path)
    # In an NCCL job the rank id is used as the device id.
    if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group":
        self._dev_id = str(get_rank())
    os.environ['DEVICE_ID'] = self._dev_id

def _ascend_profiler_init(self, kwargs):
    """
    Initialise MindData and Ascend profiling.

    Parses the Ascend keyword arguments, serialises the profiling options to JSON, initialises the
    Ascend profiler and creates ``<output_path>/container/<dev_id>/data`` if it does not exist.

    Args:
        kwargs (dict): Constructor keyword arguments, parsed by ``_parse_parameter_for_ascend``.

    Raises:
        ValueError: If the serialised profiling options are longer than 2048 characters.
    """
    # Setup and start MindData Profiling
    self._md_profiler = cde.GlobalContext.profiling_manager()
    self._md_profiler.init()
    self._init_time = int(time.time() * 10000000)
    logger.info("Profiling: profiling init time: %d", self._init_time)
    self._parse_parameter_for_ascend(kwargs)
    os.environ['DEVICE_ID'] = self._dev_id

    self._ascend_profiling_options = json.dumps(self._construct_profiling_options())
    # Characters longer than 2048 are ignored, resulting in profiling option resolution errors
    if len(self._ascend_profiling_options) > 2048:
        msg = f"For '{self.__class__.__name__}', the environment parameter length exceeds " \
              f"the limit (2048), please input valid parameters."
        logger.critical(msg)
        raise ValueError(msg)
    # use context interface to open profiling, for the new mindspore version(after 2020.5.21)
    self._ascend_profiler = c_expression.AscendProfiler.get_instance()
    self._ascend_profiler.init(self._output_path, int(self._dev_id), self._ascend_profiling_options)
    base_profiling_container_path = os.path.join(self._output_path, "container")
    container_path = os.path.join(base_profiling_container_path, self._dev_id)
    data_path = os.path.join(container_path, "data")
    data_path = validate_and_normalize_path(data_path)
    os.makedirs(data_path, exist_ok=True)

def _construct_profiling_options(self):
    """
    Construct profiling options to determine which profiling data should be collected.

    ``fp_point`` / ``bp_point`` come from the ``PROFILING_FP_START`` / ``PROFILING_BP_END``
    environment variables (empty string when unset); ``profile_memory`` and ``hccl`` are "on" or "off"
    depending on the corresponding Profiler flags.

    Returns:
        dict, the profiling options, in the key order in which they are serialised.
    """
    profile_memory = "on" if self._profile_memory else "off"
    profiler_communication = "on" if self._profile_communication else "off"

    fp_point = os.environ.get("PROFILING_FP_START", "")
    bp_point = os.environ.get("PROFILING_BP_END", "")

    return {
        "output": self._output_path,
        "fp_point": fp_point,
        "bp_point": bp_point,
        "training_trace": "on",
        "task_trace": "on",
        "aic_metrics": "ArithmeticUtilization",
        "aicpu": "on",
        "profile_memory": profile_memory,
        "hccl": profiler_communication,
        "parallel_strategy": "on"
    }

def _parse_parameter_for_gpu(self, kwargs):
    """
    Parse parameters in Profiler when the device target is GPU.

    Args:
        kwargs (dict): Constructor keyword arguments; ``start_profile``, ``profile_communication`` and
            ``profile_memory`` are popped from it.

    Raises:
        TypeError: If any of these parameters is not a bool.
        RuntimeError: If ``profile_communication`` or ``profile_memory`` is True (not supported on GPU).
    """
    self.start_profile = kwargs.pop("start_profile", True)
    if not isinstance(self.start_profile, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter start_profile must be bool, "
                        f"but got type {type(self.start_profile)}")

    self._profile_communication = kwargs.pop("profile_communication", False)
    if not isinstance(self._profile_communication, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter profile_communication must be bool, "
                        f"but got type {type(self._profile_communication)}")
    if self._profile_communication:
        raise RuntimeError("The parameter profile_communication is not supported on GPU currently.")

    self._profile_memory = kwargs.pop("profile_memory", False)
    if not isinstance(self._profile_memory, bool):
        # Report the public parameter name, not the private attribute name.
        raise TypeError(f"For '{self.__class__.__name__}', the parameter profile_memory must be bool, "
                        f"but got type {type(self._profile_memory)}")
    if self._profile_memory:
        raise RuntimeError("The parameter profile_memory is not supported on GPU currently.")

def _parse_parameter_for_ascend(self, kwargs):
    """
    Parse parameters in Profiler when the device target is Ascend.

    Args:
        kwargs (dict): Constructor keyword arguments; ``start_profile``, ``profile_communication`` and
            ``profile_memory`` are popped from it. A warning is logged if anything else remains.

    Raises:
        TypeError: If any of these parameters is not a bool.
        RuntimeError: If ``profile_communication`` is True while ``start_profile`` is False.
    """
    self.start_profile = kwargs.pop("start_profile", True)
    if not isinstance(self.start_profile, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter start_profile must be bool, "
                        f"but got type {type(self.start_profile)}")

    self._profile_communication = kwargs.pop("profile_communication", False)
    if not isinstance(self._profile_communication, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter profile_communication must be bool, "
                        f"but got type {type(self._profile_communication)}")
    if self._profile_communication:
        # Validate before touching the environment, so a rejected configuration leaves no PROFILING_OPTIONS.
        if not self.start_profile:
            raise RuntimeError(f"For '{self.__class__.__name__}', the parameter profile_communication can "
                               f"not be True while starting profiler in the process of training.")
        hccl_option = {"output": self._output_path, "task_trace": "on"}
        os.environ['PROFILING_OPTIONS'] = json.dumps(hccl_option)

    self._profile_memory = kwargs.pop("profile_memory", False)
    if not isinstance(self._profile_memory, bool):
        raise TypeError(f"For '{self.__class__.__name__}', the parameter profile_memory must be bool, "
                        f"but got type '{type(self._profile_memory)}'")

    if kwargs:
        logger.warning("There are invalid params which don't work.")

    if os.getenv("GRAPH_OP_RUN") == "1":
        logger.warning(f"For '{self.__class__.__name__}', Profiling is not supported if set environment "
                       f"'GRAPH_OP_RUN' value to 1, which means model training task is not sink.")

def _set_ascend_job_id(self, ascend_job_id):
    """
    Set the Ascend job directory for offline parsing of performance data.

    The output path becomes the parent directory of the normalised job directory.

    Raises:
        ValueError: If the job directory does not exist.
    """
    self._ascend_job_id = validate_and_normalize_path(ascend_job_id)
    if not os.path.exists(self._ascend_job_id):
        msg = f"Invalid ascend_job_id: {self._ascend_job_id}, Please pass the absolute path of the JOB dir"
        logger.critical(msg)
        raise ValueError(msg)
    self._output_path, _ = os.path.split(self._ascend_job_id)

def _is_offline_parser(self):
    """Return True when parsing Ascend data offline (an Ascend job id is set), otherwise False."""
    if self._device_target and self._device_target == DeviceTarget.ASCEND.value:
        return bool(self._ascend_job_id)
    return False

def _ascend_pynative_analyse(self):
    """Collect and analyse Ascend PyNative mode performance data (ops, MindData and timeline)."""
    op_intermediate_parser = OPIntermediateParser(self._output_path, self._rank_id)
    op_intermediate_parser.parser_pynative_op_type()
    op_intermediate_parser.parser_pynative_op_intermediate_detail()

    job_id = self._get_profiling_job_id()
    logger.info("Profiling: job id is %s ", job_id)
    self._check_output_path(output_path=self._output_path)
    source_path = os.path.join(self._output_path, job_id)
    MinddataParser.execute(source_path, self._output_path, self._rank_id)

    pipeline_parser = MinddataPipelineParser(self._output_path, self._rank_id, self._output_path)
    logger.info("Profiling: analyzing the minddata pipeline operator and queue.")
    pipeline_parser.parse()

    timeline_analyser = AscendTimelineGenerator(self._output_path, self._dev_id, self._rank_id,
                                                self._rank_size, context.get_context("mode"))
    timeline_analyser.init_pynative_timeline()
    size_limit = 100 * 1024 * 1024  # 100MB
    timeline_analyser.write_timeline(size_limit)
    timeline_analyser.write_timeline_summary()

def _ascend_analyse(self):
    """
    Collect and analyse Ascend performance data.

    Communication profiling is switched off when the communication group is not initialised. The
    profiler is stopped if it is still running, then PyNative or graph mode analysis runs.
    """
    self._rank_size = 1
    if self._profile_communication and not GlobalComm.INITED:
        self._profile_communication = False

    if GlobalComm.INITED:
        self._rank_size = get_group_size()

    self._ascend_dynamic_status = self._ascend_profiler.dynamic_status()

    if self._has_started:
        self.stop()
    else:
        logger.info("No need to stop profiler because profiler has been stopped.")

    if context.get_context("mode") == context.PYNATIVE_MODE:
        self._ascend_pynative_analyse()
    else:
        self._ascend_graph_analyse()

def _ascend_graph_memory_analyse(self, points):
    """Analyse memory usage info when memory profiling is enabled; known profiler errors are logged."""
    if not self._profile_memory:
        return
    logger.info("Profiling: analyzing the memory usage info.")
    try:
        self._analyse_memory_usage(points)
    except (ProfilerIOException, ProfilerFileNotFoundException, ProfilerRawFileException) as err:
        logger.warning(err.message)

def _ascend_graph_hccl_analyse(self):
    """Analyse HCCL profiler info when communication profiling is enabled; known errors are logged."""
    if not self._profile_communication:
        return
    logger.info("Profiling: analyzing the hccl profiler info.")
    try:
        self._analyse_hccl_info()
    except (ProfilerIOException, ProfilerFileNotFoundException, ProfilerRawFileException) as err:
        logger.warning(err.message)

def _ascend_graph_op_analyse(self, source_path):
    """
    Ascend graph mode HWTS, framework, op compute time and AI CPU analysis.

    Args:
        source_path (str): The profiling job directory.

    Returns:
        list[obj]: The list is: framework_parser, aicpu_data_parser, optime_parser, op_task_dict

    Raises:
        RuntimeError: If no task-id to op-name mapping could be built from the framework files.
    """
    # parse hwts.log.data.45.dev file, and get task profiling data
    hwts_output_filename = self._hwts_output_filename_target + self._rank_id + ".txt"
    hwts_output_filename = os.path.join(self._output_path, hwts_output_filename)
    source_path = validate_and_normalize_path(source_path)
    hwts_output_filename = validate_and_normalize_path(hwts_output_filename)
    hwtslog_parser = HWTSLogParser(source_path, hwts_output_filename, self._ascend_dynamic_status)
    logger.info("Profiling: analyzing hwts data.")
    hwtslog_parser.execute()

    # parse Framework file, and get the relation of op and tasks
    framework_parser = FrameworkParser(source_path, self._rank_id, self._output_path)
    logger.info("Profiling: analyzing framework data.")
    framework_parser.parse()
    op_task_dict = framework_parser.to_task_id_full_op_name_dict()
    if not op_task_dict:
        raise RuntimeError('Profiling: fail to parse framework files.')

    # get op compute time from hwts data and framework data, write output_op_compute_time.txt
    opcompute_output_filename = self._opcompute_output_filename_target + self._rank_id + ".txt"
    opcompute_output_filename = os.path.join(self._output_path, opcompute_output_filename)
    opcompute_output_filename = validate_and_normalize_path(opcompute_output_filename)
    optime_parser = OPComputeTimeParser(
        hwts_output_filename, opcompute_output_filename,
        op_task_dict, self._output_path, self._rank_id
    )
    logger.info("Profiling: analyzing the operation compute time.")
    optime_parser.execute()

    # parse DATA_PREPROCESS.dev.AICPU file, write output_data_preprocess_aicpu_x.txt
    output_data_preprocess_aicpu = self._aicpu_op_output_filename_target + self._rank_id + ".txt"
    output_data_preprocess_aicpu = os.path.join(self._output_path, output_data_preprocess_aicpu)
    output_data_preprocess_aicpu = validate_and_normalize_path(output_data_preprocess_aicpu)
    aicpu_data_parser = DataPreProcessParser(source_path, output_data_preprocess_aicpu, op_task_dict)
    logger.info("Profiling: analyzing the data preprocess data.")
    aicpu_data_parser.execute()

    return [framework_parser, aicpu_data_parser, optime_parser, op_task_dict]

def _ascend_graph_op_compute_time_analyse(self):
    """Analyse op compute time info; a ProfilerException is logged as a warning."""
    try:
        self._analyser_op_info()
    except ProfilerException as err:
        logger.warning(err.message)

def _ascend_graph_step_trace_analyse(self, source_path, framework_parser):
    """
    Analyse step trace info.

    Returns:
        tuple, (points, is_training_mode_flag). If step trace parsing raises a ProfilerException, the
        error is logged and (None, False) is returned.
    """
    # Defaults keep the return well-defined when parsing fails (previously an UnboundLocalError).
    points, is_training_mode_flag = None, False
    try:
        points, is_training_mode_flag = self._analyse_step_trace(source_path, framework_parser)
    except ProfilerException as err:
        logger.warning(err.message)
    return points, is_training_mode_flag

def _ascend_graph_minddata_analyse(self, source_path):
    """Analyse MindData profiling results for Ascend graph mode; ProfilerExceptions are logged."""
    # Parsing minddata AICPU profiling
    logger.info("Profiling: analyzing the minddata AICPU data.")
    MinddataParser.execute(source_path, self._output_path, self._rank_id)

    # parse minddata pipeline operator and queue
    try:
        pipeline_parser = MinddataPipelineParser(
            self._output_path, self._rank_id, self._output_path
        )
        logger.info("Profiling: analyzing the minddata pipeline operator and queue.")
        pipeline_parser.parse()
    except ProfilerException as err:
        logger.warning(err.message)

    # Analyze minddata information
    try:
        md_analyzer = MinddataProfilingAnalyzer(self._output_path, self._rank_id, self._output_path)
        logger.info("Profiling: analyzing the minddata information.")
        md_analyzer.analyze()
    except ProfilerException as err:
        logger.warning(err.message)

def _ascend_graph_analyse(self):
    """
    Ascend graph mode analyse.

    Raises:
        RuntimeError: If profile_communication or profile_memory is set on a dynamic shape network.
    """
    self._ascend_profiler.finalize()

    job_id = self._get_profiling_job_id()
    logger.info("Profiling: job id is %s ", job_id)

    self._check_output_path(output_path=self._output_path)
    source_path = os.path.join(self._output_path, job_id)
    op_parser_obj = self._ascend_graph_op_analyse(source_path)
    framework_parser, aicpu_data_parser, optime_parser, op_task_dict = op_parser_obj
    self._ascend_graph_minddata_analyse(source_path)

    # analyse op compute time info
    logger.info("Profiling: analyzing the operation compute time.")
    self._ascend_graph_op_compute_time_analyse()

    if self._ascend_dynamic_status and self._profile_communication:
        raise RuntimeError("The profile_communication parameter cannot be set on the dynamic shape network.")
    if self._ascend_dynamic_status and self._profile_memory:
        raise RuntimeError("The profile_memory parameter cannot be set on the dynamic shape network.")

    # analyse step trace info (skipped for dynamic shape networks)
    points = None
    is_training_mode_flag = False
    if not self._ascend_dynamic_status:
        logger.info("Profiling: analyzing the step trace data.")
        points, is_training_mode_flag = self._ascend_graph_step_trace_analyse(source_path, framework_parser)

    # analyse timeline info
    logger.info("Profiling: analyzing the timeline data.")
    try:
        self._analyse_timeline(aicpu_data_parser, optime_parser, source_path)
    except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err:
        logger.warning('Fail to write timeline data: %s', err)

    self._ascend_graph_memory_analyse(points)
    self._ascend_graph_hccl_analyse()

    if not self._ascend_dynamic_status:
        # get op FLOPs from aicore.data.x.slice.0 file, and compute FLOPS, write output_op_flops_x.txt
        flops_parser = FlopsParser(source_path, self._output_path, op_task_dict, self._dev_id, self._rank_id,
                                   is_training_mode_flag)
        logger.info("Profiling: analyzing the operation FLOPs.")
        flops_parser.execute()
    else:
        dynamic_parser = DynamicFrameWorkParser(self._output_path, self._rank_id)
        dynamic_parser.write_dynamic_shape_data()

def _ascend_pynative_start(self):
    """Ascend PyNative mode start profiling: initialise the PyNative profiler, then start Ascend profiling."""
    pynative_profiler = c_expression.PynativeProfiler
    self._pynative_profiler = pynative_profiler.get_instance()
    self._pynative_profiler.init(self._output_path)
    self._ascend_profiler.start()

def _ascend_graph_start(self):
    """Ascend graph mode start profiling."""
    self._ascend_profiler.start()

def _gpu_analyse(self):
    """
    Collect and analyse GPU performance data.

    Stops the profiler if it is still running, then writes the timeline, MindData and step trace results.

    Raises:
        RuntimeError: If the network has dynamic shape (checked after the analysis has run).
    """
    self._dev_id = context.get_context("device_id")
    self._rank_size = 1
    self._gpu_dynamic_status = self._gpu_profiler.dynamic_status()
    if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group":
        self._dev_id = str(get_rank())

    if GlobalComm.INITED:
        self._rank_size = get_group_size()

    if self._has_started:
        self.stop()
    else:
        logger.info("No need to stop profiler because profiler has been stopped.")

    reduce_op_type = self._get_step_reduce_op_type()
    timeline_generator = self._generate_timeline(reduce_op_type)

    # parse minddata pipeline operator and queue for GPU
    try:
        pipeline_parser = MinddataPipelineParser(self._output_path, self._dev_id, self._output_path)
        logger.info("Profiling: analyzing the minddata pipeline operator and queue for GPU.")
        pipeline_parser.parse()
    except ProfilerException as err:
        logger.warning(err.message)

    # Analyze minddata information
    try:
        md_analyzer = MinddataProfilingAnalyzer(self._output_path, self._dev_id, self._output_path)
        logger.info("Profiling: analyzing the minddata information.")
        md_analyzer.analyze()
    except ProfilerException as err:
        logger.warning(err.message)

    # analyse step trace info
    logger.info("Profiling: analyzing the step trace info.")
    try:
        self._analyse_step_trace(
            is_training_mode_flag=timeline_generator.check_op_name('Gradients'),
            is_gpu_kernel_async_launch_flag=timeline_generator.is_gpu_kernel_async_launch()
        )
    except ProfilerException as err:
        logger.warning(err.message)

    if self._gpu_dynamic_status:
        raise RuntimeError('Profiler does not support dynamic shape network on GPU platform currently.')
    logger.warning(
        '\nThe GPU supports only the training mode or inference mode, '
        'it does not support train and infer at the same time.'
    )

def _get_step_reduce_op_type(self):
    """
    Get all communication operator names from the GPU step trace file.

    Reads the first line of ``step_trace_profiling_<dev_id>.txt``, splits it on whitespace and, for
    every item from index 4 on, keeps the text before the first ',' and after the last '/'.

    Returns:
        list[str], the communication operator names.
    """
    step_trace_original_filename = f'step_trace_profiling_{self._dev_id}.txt'
    step_trace_file_path = os.path.join(self._output_path, step_trace_original_filename)
    step_trace_file_path = validate_and_normalize_path(step_trace_file_path)
    with open(step_trace_file_path, 'r') as f_obj:
        one_step_info = f_obj.readline().strip().split()
    # The communication operator starts at index 4.
    return [reduce_item.split(',')[0].split('/')[-1] for reduce_item in one_step_info[4:]]

def _cpu_analyse(self):
    """
    Collect and analyse CPU performance data by writing the CPU timeline.

    Raises:
        RuntimeError: If the timeline cannot be written, or the network has dynamic shape.
    """
    size_limit = 100 * 1024 * 1024  # 100MB
    try:
        timeline_generator = CpuTimelineGenerator(self._output_path, context.get_context("mode"))
        timeline_generator.init_timeline()
        timeline_generator.write_timeline(size_limit)
        timeline_generator.write_timeline_summary()
    except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err:
        logger.warning('Fail to write timeline data: %s', err)
        raise RuntimeError('Fail to write timeline data.') from err
    if self._cpu_dynamic_status:
        raise RuntimeError('Profiler does not support dynamic shape network on CPU platform currently.')

def _analyse_step_trace(self, source_path=None, framework_parser=None, is_training_mode_flag=True,
                        is_gpu_kernel_async_launch_flag=False):
    """
    Analyse step trace data and save the result.

    Args:
        source_path (str): The directory that contains the step trace original data (Ascend only).
        framework_parser (FrameworkParser): The framework parse instance (Ascend only).
        is_training_mode_flag (bool): Whether in training mode or not. On Ascend it is re-derived from
            the framework data.
        is_gpu_kernel_async_launch_flag (bool): Whether GPU kernels are launched asynchronously (GPU only).

    Returns:
        tuple, (point_info, is_training_mode_flag).
    """
    logger.info("Begin to parse step trace.")
    # construct output path
    dev_id = self._rank_id if self._device_target == DeviceTarget.ASCEND.value else self._dev_id
    step_trace_intermediate_file_path = os.path.join(
        self._output_path,
        f'step_trace_raw_{dev_id}_detail_time.csv'
    )
    point_info_file_path = os.path.join(
        self._output_path,
        f'step_trace_point_info_{dev_id}.json'
    )
    step_trace_intermediate_file_path = validate_and_normalize_path(step_trace_intermediate_file_path)
    point_info_file_path = validate_and_normalize_path(point_info_file_path)

    if self._device_target and self._device_target == DeviceTarget.GPU.value:
        input_file_path = os.path.join(self._output_path, f'step_trace_profiling_{self._dev_id}.txt')
        input_file_path = validate_and_normalize_path(input_file_path)
        parser = GpuStepTraceParser(input_dir=input_file_path,
                                    output_file_path=step_trace_intermediate_file_path,
                                    is_training_mode=is_training_mode_flag,
                                    is_gpu_kernel_async_launch=is_gpu_kernel_async_launch_flag)
    else:
        # whether keep the first step
        skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
        # recognize inference or training mode
        is_training_mode_flag = framework_parser.check_op_name("Gradients")
        # parser the step trace files and save the result to disk
        source_path = validate_and_normalize_path(source_path)
        parser = AscendStepTraceParser(input_dir=source_path,
                                       output_file_path=step_trace_intermediate_file_path,
                                       skip_first_step=skip_first_step_flag,
                                       is_training_mode=is_training_mode_flag)
        parser.set_task_id_op_name_dict(framework_parser.to_task_id_full_op_name_dict())
    parser.parse_and_save()
    point_info = parser.record_point_info(point_info_file_path)
    # print parser result
    parser.show()
    logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
    logger.info("The point info is: %s", point_info)

    return point_info, is_training_mode_flag
def _analyse_timeline(self, aicpu_parser, optime_parser, source_path):
    """
    Analyse and parse timeline info.

    Args:
        aicpu_parser (DataPreProcessParser): The parser instance for AI CPU operator
            execution time calculation.
        optime_parser (OPComputeTimeParser): The parser instance for AI Core
            operator execution time calculation.
        source_path (str): Path of the profiling source data, forwarded to
            ``AscendTimelineGenerator.init_timeline``.
    """
    timeline_analyser = AscendTimelineGenerator(self._output_path, self._dev_id, self._rank_id,
                                                self._rank_size, context.get_context("mode"))
    # Get framework info
    integrator = Integrator(self._output_path, self._rank_id)
    aicore_detail_data = integrator.get_aicore_detail_data()
    aicore_detail_data_size = len(aicore_detail_data)
    col_names = ['op_name', 'op_type', 'avg_execution_time', 'subgraph',
                 'full_op_name', 'op_info']
    framework_info = {
        'col_name': col_names,
        'object': aicore_detail_data,
        'size': aicore_detail_data_size
    }

    all_reduce_info = integrator.query_for_all_reduce()

    # Get timeline info
    logger.info('Start writing timeline info...')
    logger.info('Warm Prompt: It could take a few minutes if you are training '
                'with a complex network or more than 10 steps.')
    # Add info into timeline, such as AI CPU, AllReduce, framework info.
    aicpu_info = aicpu_parser.query_aicpu_data()
    # Hand the smaller of the two parsers' min_cycle_counter values to the generator.
    min_cycle_counter = min(aicpu_parser.min_cycle_counter, optime_parser.min_cycle_counter)
    timeline_analyser.init_timeline(all_reduce_info, framework_info, aicpu_info,
                                    min_cycle_counter, source_path)
    size_limit = 100 * 1024 * 1024  # 100MB
    timeline_analyser.write_timeline(size_limit)
    timeline_analyser.write_timeline_summary()

def _generate_timeline(self, reduce_op_type):
    """
    Used for gpu, generate timeline info, write to json format file.

    Args:
        reduce_op_type: Forwarded unchanged to ``GpuTimelineGenerator.init_timeline``.

    Returns:
        GpuTimelineGenerator, the generator that wrote the timeline.

    Raises:
        RuntimeError: If reading or writing the timeline data fails.
    """
    size_limit = 100 * 1024 * 1024  # 100MB
    try:
        timeline_generator = GpuTimelineGenerator(self._output_path, self._dev_id, self._rank_size,
                                                  context.get_context("mode"))
        timeline_generator.init_timeline(reduce_op_type)
        timeline_generator.write_timeline(size_limit)
        timeline_generator.write_timeline_summary()
        return timeline_generator
    except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err:
        logger.warning('Fail to write timeline data: %s', err)
        raise RuntimeError('Fail to write timeline data.') from err

def _analyse_memory_usage(self, points):
    """
    Analyse memory usage data and write the memory files.

    Args:
        points: Step-trace point info, forwarded to
            ``MemoryUsageParser.init_memory_usage_info``.
    """
    integrator = Integrator(self._output_path, self._rank_id)
    aicore_detail_data = integrator.get_aicore_detail_data()
    memory_parser = MemoryUsageParser(self._output_path, self._rank_id)
    memory_parser.init_memory_usage_info(aicore_detail_data, points)
    memory_parser.write_memory_files()

def _get_profiling_job_id(self):
    """Get profiling job id, which was generated by ada service.

    In offline mode the id is derived from ``self._ascend_job_id``. Otherwise,
    the ``JOB*``/``PROF*`` directories under ``self._output_path`` are scanned,
    newest first. The first one that matches the current device id, contains
    data and started after this training is selected.

    Returns:
        str, profiling job id. It is ``<PROF dir>/<device dir>`` for a ``PROF*``
        directory, or the directory name for a ``JOB*`` directory.

    Raises:
        RuntimeError: If no suitable profiling job directory can be found.
    """
    if self._is_offline_parser():
        # The self._ascend_job_id directory like "/../PROF***" or "/../JOB***".
        job_id = self._ascend_job_id.rstrip('/').split('/')[-1]
        if job_id.startswith('PROF'):
            device_dirs = sorted(name for name in os.listdir(self._ascend_job_id)
                                 if name.startswith('device'))
            if not device_dirs:
                raise RuntimeError("Fail to get profiling job, no device dir found in {}."
                                   .format(self._ascend_job_id))
            return os.path.join(job_id, device_dirs[0])
        return job_id

    job_id = ""
    # Parenthesised prefix test: both JOB* and PROF* entries must be directories.
    job_dirs = [
        item for item in os.listdir(self._output_path)
        if item.startswith(('JOB', 'PROF')) and os.path.isdir(os.path.join(self._output_path, item))
    ]
    sorted_job_dirs = sorted(job_dirs, key=lambda x: os.path.getmtime(os.path.join(self._output_path, x)),
                             reverse=True)
    for dir_name in sorted_job_dirs:
        device_dir = None
        if dir_name.startswith('PROF'):
            prof_dir = os.path.join(self._output_path, dir_name)
            device_dirs = sorted(
                name for name in os.listdir(prof_dir)
                if name.startswith('device') and os.path.isdir(os.path.join(prof_dir, name))
            )
            if not device_dirs:
                logger.warning("Find profiling job path %s, but no device dir exists, "
                               "profiler will ignore this job dir.", prof_dir)
                continue
            device_dir = device_dirs[0]
            job_dir = os.path.join(prof_dir, device_dir)
        else:
            job_dir = os.path.join(self._output_path, dir_name)

        host_start_file_path = get_file_path(job_dir, "host_start.log")
        if host_start_file_path is None:
            logger.warning("Find profiling job path %s, but host_start.log not exist, "
                           "profiler will ignore this job dir.", job_dir)
            continue

        # The device id is taken as the last '.'-separated part of the host_start.log path.
        training_device_id = host_start_file_path.split('.')[-1]
        if self._dev_id != training_device_id:
            logger.warning("Find profiling find job path %s, but not current training device id. "
                           "Current training device id %s, but job path device id: %s, "
                           "profiler will ignore this job dir.", job_dir, self._dev_id, training_device_id)
            continue

        data_dir = os.path.join(job_dir, 'data')
        if not os.path.isdir(data_dir) or not os.listdir(data_dir):
            continue

        job_start_time = self._parse_host_start_log(host_start_file_path)
        if not job_start_time:
            logger.warning("Find profiling job path %s, but fail to get job start info, "
                           "profiler will ignore this job dir.", job_dir)
            continue

        if int(job_start_time) < self._start_time:
            logger.warning("Find profiling job path %s, but start_time(%d) is earlier than this training "
                           "start_time(%d), profiler will ignore this job dir.",
                           job_dir, int(job_start_time), self._start_time)
            continue

        job_id = os.path.join(dir_name, device_dir) if device_dir is not None else dir_name
        break

    if not job_id:
        msg = "Fail to get profiling job, output path is {}, " \
              "please check whether job dir or prof dir(name startswith JOB or PROF) in output path " \
              "was generated, or may be the device id from job dir dismatch the " \
              "device_id in current process.".format(self._output_path)
        raise RuntimeError(msg)

    return job_id

def _analyser_op_info(self):
    """
    Analyse the operator information.

    Runs the integrator, then writes the per-op-type summary and the per-op
    detail table to ``output_op_compute_time_detail_<rank_id>.txt``. Both
    sections are also printed.
    """
    integrator = Integrator(self._output_path, self._rank_id)
    integrator.integrate()

    aicore_type_result = self._query_op_type_info()
    detail_file_path = os.path.join(
        self._output_path,
        'output_op_compute_time_detail_{}.txt'.format(self._rank_id)
    )
    fwrite_format(detail_file_path, data_source='title:op compute time')
    display_names = [
        'optype_name', 'compute_time(ms, per-step)',
        'called_times(per-step)', 'percent'
    ]
    fwrite_format(detail_file_path, data_source=" ".join(display_names), is_print=True)
    fwrite_format(detail_file_path, data_source=aicore_type_result, is_print=True)

    # Order the detail table by the op types, as they appear in the summary.
    op_type_order = [item[0] for item in aicore_type_result]
    aicore_detail_result = self._query_op_detail_info(op_type_order)

    fwrite_format(detail_file_path, data_source='', is_print=True)
    fwrite_format(detail_file_path, data_source='Detail:', is_print=True)
    fwrite_format(detail_file_path, data_source=" ".join(aicore_detail_result.get('col_name_detail')),
                  is_print=True)
    fwrite_format(detail_file_path, data_source=aicore_detail_result.get('object'), is_print=True)

def _query_op_type_info(self):
    """
    Query AICORE operator type information.

    Returns:
        list[list], the AICORE operator type and execution time information.
    """
    integrator = Integrator(self._output_path, self._rank_id)
    return integrator.get_aicore_data()

def _query_op_detail_info(self, op_type_order):
    """
    Query AICORE operator detail information.

    Args:
        op_type_order(list): The name of the op type in order.

    Returns:
        dict, the AICORE operator detail information.
    """

    op_type_condition = {}
    # Exclude op types the user asked to filter out, if any.
    if self._filt_optype_names:
        op_type_condition['not_in'] = self._filt_optype_names

    filter_condition = {
        'op_type': op_type_condition,
        'is_display_detail': False,
    }
    integrator = Integrator(self._output_path, self._rank_id)
    return integrator.query_and_sort_by_op_type(filter_condition, op_type_order)

def _get_devid_rankid_and_devtarget(self):
    """
    Get device id and rank id and target of this training.

    Sets ``self._dev_id``, ``self._device_target`` and ``self._rank_id``. The
    device id falls back to the DEVICE_ID environment variable, then to "0".
    The rank id comes from RANK_ID and falls back to "0".

    Raises:
        RuntimeError: If the device target is not Ascend, GPU or CPU.
    """

    device_target = ""
    dev_id = ""
    try:
        dev_id = str(context.get_context("device_id"))
        device_target = context.get_context("device_target").lower()
    except ValueError as err:
        logger.error("Profiling: fail to get context, %s", err)

    if not dev_id or not dev_id.isdigit():
        dev_id = os.getenv('DEVICE_ID')
    if not dev_id or not dev_id.isdigit():
        dev_id = "0"
        logger.warning("Fail to get DEVICE_ID, use 0 instead.")

    if device_target and device_target not in [DeviceTarget.ASCEND.value, DeviceTarget.GPU.value,
                                               DeviceTarget.CPU.value]:
        msg = "Profiling: unsupported backend: %s" % device_target
        raise RuntimeError(msg)

    rank_id = os.getenv("RANK_ID")
    if not rank_id or not rank_id.isdigit():
        rank_id = "0"
        logger.warning("For '%s', fail to get RANK_ID from environment, use 0 instead.",
                       self.__class__.__name__)

    self._dev_id = dev_id
    self._device_target = device_target.lower()
    self._rank_id = rank_id

def _get_output_path(self, kwargs):
    """
    Get output path of profiling data.

    The base path is taken from ``kwargs["output_path"]``. If that is absent
    or None, the MS_DIAGNOSTIC_DATA_PATH environment variable is used, and
    failing that, "data". The key is popped from ``kwargs``. The result,
    ``<base>/profiler``, is stored in ``self._output_path``. It is created
    with owner-only rwx permissions when it does not exist yet.

    Args:
        kwargs (dict): Keyword arguments given to the profiler; mutated in place.
    """
    if os.getenv("MS_DIAGNOSTIC_DATA_PATH") and kwargs.get("output_path") is not None:
        logger.warning("Both parameter output_path and environment variable MS_DIAGNOSTIC_DATA_PATH"
                       " have values set, and the profiling data saving path is the value set "
                       "in parameter output_path")
    output_path = kwargs.pop("output_path", None)
    if output_path is None:
        # Environment variables are mainly set for the convenience of cloud profiler.
        output_path = os.getenv("MS_DIAGNOSTIC_DATA_PATH") or "data"
    self._output_path = validate_and_normalize_path(output_path)
    self._output_path = os.path.join(self._output_path, "profiler")
    if not os.path.exists(self._output_path):
        os.makedirs(self._output_path, exist_ok=True)
        os.chmod(self._output_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
    else:
        logger.warning("The target dir already exists. "
                       "There may be some old profiling data, and they will be rewritten in the end.")

def _analyse_hccl_info(self):
    """
    Analyse hccl info.

    Calls ``hccl_parser.entry.hccl_parse_op`` to write raw results into
    ``hccl_info_<rank_id>``, then runs HcclParser over that directory.

    Raises:
        ImportError: If the hccl_parser package is not installed.
    """
    hccl_path = os.path.join(self._output_path, "hccl_info_{}".format(self._rank_id))
    if not os.path.exists(hccl_path):
        os.makedirs(hccl_path, exist_ok=True)
        os.chmod(hccl_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
    logger.info("Start call the interface HCCLParseOP parsing hccl info...")
    logger.info('Warm Prompt: It could take a few minutes if you are training '
                'with a complex network or more than 10 steps.')
    # Call the interface HCCLParseOP parsing hccl info.
    # The import must be inside the try, so that a missing package is reported with the hint below.
    try:
        from hccl_parser.entry import hccl_parse_op
        hccl_parse_op(self._dev_id, self._output_path, hccl_path, op_type='all')
    except ImportError as err:
        logger.critical("%s,please check if the hccl_parser-{version}-py3-none-any.whl is installed."
                        "The hccl_parser-{version}-py3-none-any.whl package is usually located "
                        "in the /usr/local/Ascend/tools Directory", err)
        raise ImportError(err) from err
    logger.info("Parse hccl info successfully.")
    logger.info("Start analyse hccl info.")
    hccl_parse = HcclParser(hccl_path, self._dev_id, self._rank_id, self._output_path)
    hccl_parse.parse()
    logger.info("Analyse hccl info successfully.")