Class TrainCfg
Defined in File cfg.h
Class Documentation
-
class TrainCfg
Public Functions
-
TrainCfg() = default
Constructor of training config.
-
inline TrainCfg(const TrainCfg &rhs)
Copy constructor of training config.
- Parameters
rhs – [in] The training config to copy from.
-
~TrainCfg() = default
Destructor of training config.
-
inline std::vector<std::string> GetLossName() const
Obtain the part of the name that identifies a loss kernel.
- Returns
loss_name.
-
inline void SetLossName(const std::vector<std::string> &loss_name)
Set the part of the name that identifies a loss kernel.
- Parameters
loss_name – [in] The part of the name that identifies a loss kernel.
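The accessors above could be exercised roughly as follows. This is a minimal sketch, not the library's implementation: since cfg.h is not reproduced here, `TrainCfgSketch` is a hypothetical stand-in that mirrors only the documented signatures, and the helper `MakeLossNameCfg`, the private `loss_name_` storage, and the name fragments `"my_loss"` and `"custom_loss_fn"` are all assumptions made for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for TrainCfg: it mirrors only the documented
// constructor and GetLossName/SetLossName signatures, so the sketch
// compiles without cfg.h.
class TrainCfgSketch {
 public:
  TrainCfgSketch() = default;
  TrainCfgSketch(const TrainCfgSketch &rhs) : loss_name_(rhs.loss_name_) {}
  ~TrainCfgSketch() = default;

  std::vector<std::string> GetLossName() const { return loss_name_; }
  void SetLossName(const std::vector<std::string> &loss_name) {
    loss_name_ = loss_name;
  }

 private:
  std::vector<std::string> loss_name_;  // assumed storage; not documented
};

// Hypothetical helper: sets the name fragments used to identify loss
// kernels, then copies the config through the copy constructor.
inline TrainCfgSketch MakeLossNameCfg() {
  TrainCfgSketch cfg;
  cfg.SetLossName({"my_loss", "custom_loss_fn"});  // illustrative fragments
  TrainCfgSketch copy(cfg);  // the stand-in's copy keeps the same names
  assert(copy.GetLossName() == cfg.GetLossName());
  return copy;
}
```

With the real class, the same calls would presumably be made on a `TrainCfg` object after including cfg.h. Note that `GetLossName()` returns the vector by value, so modifying the returned vector does not change the config.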
Public Members
-
OptimizationLevel optimization_level_ = kO0
Optimization level.
-
MixPrecisionCfg mix_precision_cfg_
Mix precision configuration.
-
bool accumulate_gradients_ = false
Whether gradient accumulation is used.
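Because these members are public, they can be assigned directly on a config object. The sketch below illustrates this under stated assumptions: `TrainCfgMembersSketch`, `OptimizationLevelSketch`, and `MixPrecisionCfgSketch` are hypothetical stand-ins, the helper `MakeAccumulatingCfg` is invented for the example, and the enumerator `kO2` is assumed (only `kO0` appears in this documentation).

```cpp
#include <cassert>

// Hypothetical stand-ins: only kO0 appears in the documentation above;
// kO2 and the empty MixPrecisionCfgSketch are assumptions for illustration.
enum OptimizationLevelSketch { kO0 = 0, kO2 = 2 };
struct MixPrecisionCfgSketch {};

// Mirrors the three documented public members and their defaults.
struct TrainCfgMembersSketch {
  OptimizationLevelSketch optimization_level_ = kO0;
  MixPrecisionCfgSketch mix_precision_cfg_;
  bool accumulate_gradients_ = false;
};

// Hypothetical helper: starts from the documented defaults, then enables
// gradient accumulation and raises the optimization level.
inline TrainCfgMembersSketch MakeAccumulatingCfg() {
  TrainCfgMembersSketch cfg;
  assert(cfg.optimization_level_ == kO0);  // documented default
  assert(!cfg.accumulate_gradients_);      // documented default
  cfg.optimization_level_ = kO2;           // assumed enumerator
  cfg.accumulate_gradients_ = true;
  return cfg;
}
```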