:gitee_url: https://gitee.com/mindspore/docs


.. _program_listing_file_include_context.h:

Program Listing for File context.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_include_context.h>` (``include/context.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. Reviewer note: this page had been collapsed onto a few lines and every
   angle-bracketed span (system header names, template argument lists, the
   ``:ref:`` target) had been stripped. They are restored below. The explanatory
   ``//`` comments were added during review. TODO confirm the restored header
   names and the element types (``int``, ``int32_t``, ``size_t``, ``char``)
   against the upstream ``include/api/context.h``.

.. code-block:: cpp

   #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
   #define MINDSPORE_INCLUDE_API_CONTEXT_H

   #include <string>
   #include <memory>
   #include <vector>
   #include <map>
   #include "include/api/types.h"
   #include "include/api/dual_abi_helper.h"

   namespace mindspore {
   // Values returned by the DeviceInfoContext::GetDeviceType() overrides below.
   enum DeviceType {
     kCPU = 0,
     kGPU,
     kKirinNPU,
     kAscend910,
     kAscend310,
     kHexagonDSP = 6,  // explicit value; 5 is not assigned to any enumerator here
     // add new type here
     kInvalidDeviceType = 100,
   };

   class Allocator;
   class Delegate;
   class DeviceInfoContext;

   // Holds thread, parallelism, delegate and device-info settings.
   // Every member function other than the defaulted destructor is only declared
   // in this header; state lives behind `data_`, whose type `Data` is only
   // forward-declared here.
   class Context {
    public:
     struct Data;
     Context();
     ~Context() = default;

     void SetThreadNum(int32_t thread_num);
     int32_t GetThreadNum() const;

     // Two overloads: one takes an integer mode, the other an explicit core list.
     void SetThreadAffinity(int mode);
     int GetThreadAffinityMode() const;

     void SetThreadAffinity(const std::vector<int> &core_list);
     std::vector<int32_t> GetThreadAffinityCoreList() const;

     void SetEnableParallel(bool is_parallel);
     bool GetEnableParallel() const;

     void SetDelegate(const std::shared_ptr<Delegate> &delegate);
     std::shared_ptr<Delegate> GetDelegate() const;

     // Returns a non-const reference, so callers can modify the list in place.
     std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

    private:
     std::shared_ptr<Data> data_;
   };

   // Abstract base for per-device settings; subclasses identify themselves through
   // GetDeviceType(). Derives from std::enable_shared_from_this so that Cast() can
   // obtain a shared_ptr to the current object.
   class DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
    public:
     struct Data;

     DeviceInfoContext();
     virtual ~DeviceInfoContext() = default;

     virtual enum DeviceType GetDeviceType() const = 0;

     // Downcasts this object to std::shared_ptr<T>.
     // T must derive from DeviceInfoContext (enforced by the static_assert) and be
     // default-constructible, because a temporary T is created to read its device
     // type. Returns nullptr when that type differs from this object's type.
     template <class T>
     std::shared_ptr<T> Cast() {
       static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
       if (GetDeviceType() != T().GetDeviceType()) {
         return nullptr;
       }

       return std::static_pointer_cast<T>(shared_from_this());
     }

     // The std::string overloads are defined inline after the class and forward to
     // the protected std::vector<char> overloads.
     inline std::string GetProvider() const;
     inline void SetProvider(const std::string &provider);

     inline std::string GetProviderDevice() const;
     inline void SetProviderDevice(const std::string &device);

     void SetAllocator(const std::shared_ptr<Allocator> &allocator);
     std::shared_ptr<Allocator> GetAllocator() const;

    protected:
     std::vector<char> GetProviderChar() const;
     void SetProvider(const std::vector<char> &provider);

     std::vector<char> GetProviderDeviceChar() const;
     void SetProviderDevice(const std::vector<char> &device);

     std::shared_ptr<Data> data_;
   };

   // Inline std::string wrappers: convert with StringToChar / CharToString
   // (presumably declared in include/api/dual_abi_helper.h) and forward to the
   // std::vector<char> overloads.
   std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
   void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }

   std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
   void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }

   // Device info whose GetDeviceType() returns DeviceType::kCPU.
   class CPUDeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };

     void SetEnableFP16(bool is_fp16);
     bool GetEnableFP16() const;
   };

   // Device info whose GetDeviceType() returns DeviceType::kKirinNPU.
   class KirinNPUDeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };

     void SetFrequency(int frequency);
     int GetFrequency() const;
   };

   // Device info whose GetDeviceType() returns DeviceType::kGPU.
   class GPUDeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };

     void SetDeviceID(uint32_t device_id);
     uint32_t GetDeviceID() const;

     // std::string overloads; defined inline after the class.
     inline void SetPrecisionMode(const std::string &precision_mode);
     inline std::string GetPrecisionMode() const;

     void SetEnableFP16(bool is_fp16);
     bool GetEnableFP16() const;

    private:
     void SetPrecisionMode(const std::vector<char> &precision_mode);
     std::vector<char> GetPrecisionModeChar() const;
   };

   // Inline std::string wrappers forwarding to the private std::vector<char> overloads.
   void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
     SetPrecisionMode(StringToChar(precision_mode));
   }
   std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

   // Device info whose GetDeviceType() returns DeviceType::kAscend910.
   class Ascend910DeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };

     void SetDeviceID(uint32_t device_id);
     uint32_t GetDeviceID() const;
   };

   // Device info whose GetDeviceType() returns DeviceType::kAscend310.
   // Each std::string accessor declared `inline` here has a private
   // std::vector<char> counterpart and an inline definition after the class.
   class Ascend310DeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };

     void SetDeviceID(uint32_t device_id);
     uint32_t GetDeviceID() const;

     inline void SetInsertOpConfigPath(const std::string &cfg_path);
     inline std::string GetInsertOpConfigPath() const;

     inline void SetInputFormat(const std::string &format);
     inline std::string GetInputFormat() const;

     inline void SetInputShape(const std::string &shape);
     inline std::string GetInputShape() const;

     void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
     std::map<int, std::vector<int>> GetInputShapeMap() const;

     // The setter takes a vector while the getter returns a std::string.
     void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
     inline std::string GetDynamicBatchSize() const;

     void SetOutputType(enum DataType output_type);
     enum DataType GetOutputType() const;

     inline void SetPrecisionMode(const std::string &precision_mode);
     inline std::string GetPrecisionMode() const;

     inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
     inline std::string GetOpSelectImplMode() const;

     inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
     inline std::string GetFusionSwitchConfigPath() const;

     // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
     inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
     inline std::string GetBufferOptimizeMode() const;

    private:
     void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
     std::vector<char> GetInsertOpConfigPathChar() const;

     void SetInputFormat(const std::vector<char> &format);
     std::vector<char> GetInputFormatChar() const;

     void SetInputShape(const std::vector<char> &shape);
     std::vector<char> GetInputShapeChar() const;

     std::vector<char> GetDynamicBatchSizeChar() const;

     void SetPrecisionMode(const std::vector<char> &precision_mode);
     std::vector<char> GetPrecisionModeChar() const;

     void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
     std::vector<char> GetOpSelectImplModeChar() const;

     void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
     std::vector<char> GetFusionSwitchConfigPathChar() const;

     void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
     std::vector<char> GetBufferOptimizeModeChar() const;
   };

   // Inline std::string wrappers forwarding to the private std::vector<char> overloads.
   void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
     SetInsertOpConfigPath(StringToChar(cfg_path));
   }
   std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

   void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
   std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

   void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
   std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

   std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

   void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
     SetPrecisionMode(StringToChar(precision_mode));
   }
   std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

   void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
     SetOpSelectImplMode(StringToChar(op_select_impl_mode));
   }
   std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

   void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
     SetFusionSwitchConfigPath(StringToChar(cfg_path));
   }
   std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
     return CharToString(GetFusionSwitchConfigPathChar());
   }

   void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
     SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
   }
   std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }

   // Device info whose GetDeviceType() returns DeviceType::kHexagonDSP; it declares
   // no settings of its own.
   class HexagonDspDeviceInfo : public DeviceInfoContext {
    public:
     enum DeviceType GetDeviceType() const override { return DeviceType::kHexagonDSP; };
   };
   }  // namespace mindspore
   #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H