Class TrainCfg

Class Documentation

class TrainCfg

Public Functions

TrainCfg() = default

Default constructor of the training config.

inline TrainCfg(const TrainCfg &rhs)

Copy constructor of the training config.

Parameters

rhs[in] The training config to copy from.

~TrainCfg() = default

Destructor of the training config.

inline std::vector<std::string> GetLossName() const

Obtains the name parts that identify a loss kernel.

Returns

loss_name, the name parts that identify a loss kernel.

inline void SetLossName(const std::vector<std::string> &loss_name)

Sets the name parts that identify a loss kernel.

Parameters

loss_name[in] The name parts that identify a loss kernel.
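This documentation says only that each entry is "part of the name" that identifies a loss kernel. The sketch below assumes substring matching: a kernel is treated as a loss kernel if its name contains any of the configured parts. `IsLossKernel` is a hypothetical helper for illustration and is not part of the MindSpore Lite API.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not MindSpore Lite API). Assumption: a kernel is a
// loss kernel if its name contains any configured name part as a substring.
bool IsLossKernel(const std::string &kernel_name,
                  const std::vector<std::string> &loss_name_parts) {
  for (const auto &part : loss_name_parts) {
    // Skip empty parts: every string contains "" and would match trivially.
    if (!part.empty() && kernel_name.find(part) != std::string::npos) {
      return true;
    }
  }
  return false;
}
```

Under this assumption, a short part such as a common suffix can match several kernels at once, so parts passed to SetLossName are best chosen to be specific to the loss layer.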

Public Members

OptimizationLevel optimization_level_ = kO0

Optimization level.

MixPrecisionCfg mix_precision_cfg_

Mixed-precision configuration.

bool accumulate_gradients_ = false

Whether gradient accumulation is enabled.
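For readers unfamiliar with the technique behind accumulate_gradients_, here is a generic sketch of gradient accumulation: per-step gradients are summed into a buffer, and the weights are updated only once every `accum_steps` steps using the averaged gradient. `GradAccumulator`, `accum_steps`, and the plain SGD update are hypothetical choices for illustration. This sketch does not describe how MindSpore Lite schedules the update when the flag is set.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration of gradient accumulation (not MindSpore code).
class GradAccumulator {
 public:
  GradAccumulator(std::size_t num_params, int accum_steps)
      : buffer_(num_params, 0.0f), accum_steps_(accum_steps) {
    assert(accum_steps > 0);
  }

  // Adds one step's gradients; returns true when a weight update was applied.
  bool Step(const std::vector<float> &grads, std::vector<float> &weights,
            float lr) {
    assert(grads.size() == buffer_.size() && weights.size() == buffer_.size());
    for (std::size_t i = 0; i < grads.size(); ++i) buffer_[i] += grads[i];
    if (++count_ < accum_steps_) return false;  // keep accumulating
    // Plain SGD with the averaged gradient, then reset the buffer.
    for (std::size_t i = 0; i < weights.size(); ++i) {
      weights[i] -= lr * buffer_[i] / static_cast<float>(accum_steps_);
      buffer_[i] = 0.0f;
    }
    count_ = 0;
    return true;
  }

 private:
  std::vector<float> buffer_;  // running sum of gradients
  int accum_steps_;            // steps per weight update
  int count_ = 0;              // steps accumulated so far
};
```

Accumulating over several small batches approximates one update on a larger batch, which is useful when device memory limits the batch size.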