:gitee_url: https://gitee.com/mindspore/docs

.. _program_listing_file_include_callback.h:

Program Listing for File callback.h
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_include_callback.h>` (``include/callback.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review): the names of the four standard headers below were lost when
   this page was flattened; they are restored here from the upstream header as
   <cstddef>, <string>, <vector>, <memory> -- confirm against
   include/api/callback/callback.h before publishing.

.. code-block:: cpp

   #ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
   #define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H

   #include <cstddef>
   #include <string>
   #include <vector>
   #include <memory>
   #include "include/api/data_type.h"
   #include "include/api/dual_abi_helper.h"

   namespace mindspore {
   class Model;
   class ModelImpl;
   class CallbackImpl;

   // Plain data bundle passed by const reference to every TrainCallBack hook.
   struct TrainCallBackData {
     // epoch and step are accepted as int but stored in unsigned int members,
     // so the values are converted on initialization.
     TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch),
                       step_(step), model_(model) {}

     bool train_mode_;        // mode flag supplied by the caller
     unsigned int epoch_;     // epoch value supplied by the caller
     unsigned int step_ = 0;  // step value supplied by the caller
     Model *model_;           // raw Model pointer; this struct never deletes it
   };

   // Result type of TrainCallBack::EpochEnd.
   enum CallbackRetValue : uint32_t {
     kContinue = 0,
     kStopTraining = 1,
     kExit = 2,
     kUnknownRetValue = 0xFFFFFFFF
   };

   // Base class for training callbacks. Every hook is virtual and has an empty
   // default body, except EpochEnd, whose default returns kContinue.
   class TrainCallBack {
    public:
     virtual ~TrainCallBack() = default;

     virtual void Begin(const TrainCallBackData &cb_data) {}

     virtual void End(const TrainCallBackData &cb_data) {}

     virtual void EpochBegin(const TrainCallBackData &cb_data) {}

     virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }

     virtual void StepBegin(const TrainCallBackData &cb_data) {}

     virtual void StepEnd(const TrainCallBackData &cb_data) {}

    protected:
     // Model and ModelImpl are friends so they can reach callback_impl_.
     friend class Model;
     friend class ModelImpl;
     CallbackImpl* callback_impl_ = nullptr;
   };

   }  // namespace mindspore
   #endif  // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H