[{"data":1,"prerenderedAt":842},["ShallowReactive",2],{"content-query-cujhWwaKLo":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":836,"_id":837,"_source":838,"_file":839,"_stem":840,"_extension":841},"/technology-blogs/zh/1691","zh",false,"","【MindSpore易点通】模型测试和验证","训练完成之后，需要测试模型在测试集上的表现。依据模型评估方式的不同，分为两种情况","2022-08-12","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2022/08/15/4730447b3ab34d5086a698572dc04ad7.png","technology-blogs","实践",{"type":15,"children":16,"toc":826},"root",[17,25,52,58,63,68,73,78,83,91,96,104,109,117,122,127,132,137,142,147,152,157,162,167,172,177,182,187,192,197,202,221,226,231,236,241,246,251,256,261,266,271,276,281,286,291,296,301,306,311,316,321,326,330,334,342,346,354,358,366,370,374,379,384,389,394,399,404,409,413,417,421,425,429,433,438,442,466,471,476,481,486,491,496,501,506,511,516,521,526,530,534,539,544,549,554,558,562,566,570,574,578,582,586,591,596,601,606,611,616,621,626,631,636,641,646,651,656,661,666,671,676,681,686,690,694,698,703,707,712,716,720,724,728,732,736,740,744,748,753,758,763,768,773,797,802,807,812,817],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"mindspore易点通模型测试和验证",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":29},"h2",{"id":28},"_1-模型测试",[30],{"type":18,"tag":31,"props":32,"children":33},"strong",{},[34,42,44],{"type":18,"tag":31,"props":35,"children":36},{},[37],{"type":18,"tag":31,"props":38,"children":39},{},[40],{"type":24,"value":41},"1",{"type":24,"value":43}," 
",{"type":18,"tag":31,"props":45,"children":46},{},[47],{"type":18,"tag":31,"props":48,"children":49},{},[50],{"type":24,"value":51},"模型测试",{"type":18,"tag":53,"props":54,"children":55},"p",{},[56],{"type":24,"value":57},"在训练完成之后，需要测试模型在测试集上的表现。依据模型评估方式的不同，分以下两种情况：",{"type":18,"tag":53,"props":59,"children":60},{},[61],{"type":24,"value":62},"1.评估方式在MindSpore中已实现",{"type":18,"tag":53,"props":64,"children":65},{},[66],{"type":24,"value":67},"MindSpore中提供了多种Metrics方式：Accuracy、Precision、Recall、F1、TopKCategoricalAccuracy、Top1CategoricalAccuracy、Top5CategoricalAccuracy、MSE、MAE、Loss。在测试中调用MindSpore已有的评估函数时，需要定义一个dict，包含要使用的评估方式，并在定义model时传入，后续调用model.eval()会返回一个dict，内容即为metrics的指标和结果。",{"type":18,"tag":53,"props":69,"children":70},{},[71],{"type":24,"value":72},"...def test_net(network, model, test_data_path, test_batch):",{"type":18,"tag":53,"props":74,"children":75},{},[76],{"type":24,"value":77},"\"\"\"define the evaluation method\"\"\"",{"type":18,"tag":53,"props":79,"children":80},{},[81],{"type":24,"value":82},"print(\"============== Start Testing ==============\")",{"type":18,"tag":53,"props":84,"children":85},{},[86],{"type":18,"tag":31,"props":87,"children":88},{},[89],{"type":24,"value":90},"# load the saved model for evaluation",{"type":18,"tag":53,"props":92,"children":93},{},[94],{"type":24,"value":95},"param_dict = load_checkpoint(\"./train_resnet_cifar10-1_390.ckpt\")",{"type":18,"tag":53,"props":97,"children":98},{},[99],{"type":18,"tag":31,"props":100,"children":101},{},[102],{"type":24,"value":103},"#load parameter to the network",{"type":18,"tag":53,"props":105,"children":106},{},[107],{"type":24,"value":108},"load_param_into_net(network, param_dict)",{"type":18,"tag":53,"props":110,"children":111},{},[112],{"type":18,"tag":31,"props":113,"children":114},{},[115],{"type":24,"value":116},"#load testing dataset",{"type":18,"tag":53,"props":118,"children":119},{},[120],{"type":24,"value":121},"ds_test = create_dataset(test_data_path, 
do_train=False,",{"type":18,"tag":53,"props":123,"children":124},{},[125],{"type":24,"value":126},"batch_size=test_batch)",{"type":18,"tag":53,"props":128,"children":129},{},[130],{"type":24,"value":131},"acc = model.eval(ds_test, dataset_sink_mode=False)",{"type":18,"tag":53,"props":133,"children":134},{},[135],{"type":24,"value":136},"print(\"============== test result:{} ==============\".format(acc))",{"type":18,"tag":53,"props":138,"children":139},{},[140],{"type":24,"value":141},"if __name__ == \"__main__\":",{"type":18,"tag":53,"props":143,"children":144},{},[145],{"type":24,"value":146},"...",{"type":18,"tag":53,"props":148,"children":149},{},[150],{"type":24,"value":151},"net = resnet()",{"type":18,"tag":53,"props":153,"children":154},{},[155],{"type":24,"value":156},"loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True,",{"type":18,"tag":53,"props":158,"children":159},{},[160],{"type":24,"value":161},"reduction='mean')",{"type":18,"tag":53,"props":163,"children":164},{},[165],{"type":24,"value":166},"opt = nn.SGD(net.trainable_params(), LR_ORI, MOMENTUM_ORI, WEIGHT_DECAY)",{"type":18,"tag":53,"props":168,"children":169},{},[170],{"type":24,"value":171},"metrics = {",{"type":18,"tag":53,"props":173,"children":174},{},[175],{"type":24,"value":176},"'accuracy': nn.Accuracy(),",{"type":18,"tag":53,"props":178,"children":179},{},[180],{"type":24,"value":181},"'loss': nn.Loss()",{"type":18,"tag":53,"props":183,"children":184},{},[185],{"type":24,"value":186},"}",{"type":18,"tag":53,"props":188,"children":189},{},[190],{"type":24,"value":191},"model_constructed = Model(net, loss, opt, metrics=metrics)",{"type":18,"tag":53,"props":193,"children":194},{},[195],{"type":24,"value":196},"test_net(net, model_constructed, TEST_PATH, 
TEST_BATCH_SIZE)",{"type":18,"tag":53,"props":198,"children":199},{},[200],{"type":24,"value":201},"2.评估方式在MindSpore中没有实现",{"type":18,"tag":53,"props":203,"children":204},{},[205,207,219],{"type":24,"value":206},"如果MindSpore中的评估函数不能满足要求，可参考",{"type":18,"tag":208,"props":209,"children":213},"a",{"href":210,"rel":211},"https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/nn/metrics/accuracy.py",[212],"nofollow",[214],{"type":18,"tag":31,"props":215,"children":216},{},[217],{"type":24,"value":218},"accuracy.py",{"type":24,"value":220},"，通过继承Metric基类完成Metric定义，并重写clear、update、eval三个方法即可。之后调用model.predict()接口得到网络输出，再按照自定义评估标准计算结果。",{"type":18,"tag":53,"props":222,"children":223},{},[224],{"type":24,"value":225},"下面以计算测试集精度为例，实现自定义Metrics：",{"type":18,"tag":53,"props":227,"children":228},{},[229],{"type":24,"value":230},"class AccuracyV2(EvaluationBase):",{"type":18,"tag":53,"props":232,"children":233},{},[234],{"type":24,"value":235},"def __init__(self, eval_type='classification'):",{"type":18,"tag":53,"props":237,"children":238},{},[239],{"type":24,"value":240},"super(AccuracyV2, self).__init__(eval_type)",{"type":18,"tag":53,"props":242,"children":243},{},[244],{"type":24,"value":245},"self.clear()",{"type":18,"tag":53,"props":247,"children":248},{},[249],{"type":24,"value":250},"def clear(self):",{"type":18,"tag":53,"props":252,"children":253},{},[254],{"type":24,"value":255},"\"\"\"Clears the internal evaluation result.\"\"\"",{"type":18,"tag":53,"props":257,"children":258},{},[259],{"type":24,"value":260},"self._correct_num = 0",{"type":18,"tag":53,"props":262,"children":263},{},[264],{"type":24,"value":265},"self._total_num = 0",{"type":18,"tag":53,"props":267,"children":268},{},[269],{"type":24,"value":270},"def update(self, output_y, label_input):",{"type":18,"tag":53,"props":272,"children":273},{},[274],{"type":24,"value":275},"y_pred = 
self._convert_data(output_y)",{"type":18,"tag":53,"props":277,"children":278},{},[279],{"type":24,"value":280},"y = self._convert_data(label_input)",{"type":18,"tag":53,"props":282,"children":283},{},[284],{"type":24,"value":285},"indices = y_pred.argmax(axis=1)",{"type":18,"tag":53,"props":287,"children":288},{},[289],{"type":24,"value":290},"results = (np.equal(indices, y) * 1).reshape(-1)",{"type":18,"tag":53,"props":292,"children":293},{},[294],{"type":24,"value":295},"self._correct_num += results.sum()",{"type":18,"tag":53,"props":297,"children":298},{},[299],{"type":24,"value":300},"self._total_num += label_input.shape[0]",{"type":18,"tag":53,"props":302,"children":303},{},[304],{"type":24,"value":305},"def eval(self):",{"type":18,"tag":53,"props":307,"children":308},{},[309],{"type":24,"value":310},"if self._total_num == 0:",{"type":18,"tag":53,"props":312,"children":313},{},[314],{"type":24,"value":315},"raise RuntimeError('Accuracy can not be calculated')",{"type":18,"tag":53,"props":317,"children":318},{},[319],{"type":24,"value":320},"return self._correct_num / self._total_num",{"type":18,"tag":53,"props":322,"children":323},{},[324],{"type":24,"value":325},"def test_net(network, model, test_data_path, test_batch):",{"type":18,"tag":53,"props":327,"children":328},{},[329],{"type":24,"value":77},{"type":18,"tag":53,"props":331,"children":332},{},[333],{"type":24,"value":82},{"type":18,"tag":53,"props":335,"children":336},{},[337],{"type":18,"tag":31,"props":338,"children":339},{},[340],{"type":24,"value":341},"# Load the saved model for evaluation",{"type":18,"tag":53,"props":343,"children":344},{},[345],{"type":24,"value":95},{"type":18,"tag":53,"props":347,"children":348},{},[349],{"type":18,"tag":31,"props":350,"children":351},{},[352],{"type":24,"value":353},"# Load parameter to the 
network",{"type":18,"tag":53,"props":355,"children":356},{},[357],{"type":24,"value":108},{"type":18,"tag":53,"props":359,"children":360},{},[361],{"type":18,"tag":31,"props":362,"children":363},{},[364],{"type":24,"value":365},"# Load testing dataset",{"type":18,"tag":53,"props":367,"children":368},{},[369],{"type":24,"value":121},{"type":18,"tag":53,"props":371,"children":372},{},[373],{"type":24,"value":126},{"type":18,"tag":53,"props":375,"children":376},{},[377],{"type":24,"value":378},"metric = AccuracyV2()",{"type":18,"tag":53,"props":380,"children":381},{},[382],{"type":24,"value":383},"metric.clear()",{"type":18,"tag":53,"props":385,"children":386},{},[387],{"type":24,"value":388},"for data, label in ds_test.create_tuple_iterator():",{"type":18,"tag":53,"props":390,"children":391},{},[392],{"type":24,"value":393},"output = model.predict(data)",{"type":18,"tag":53,"props":395,"children":396},{},[397],{"type":24,"value":398},"metric.update(output, label)",{"type":18,"tag":53,"props":400,"children":401},{},[402],{"type":24,"value":403},"results = metric.eval()",{"type":18,"tag":53,"props":405,"children":406},{},[407],{"type":24,"value":408},"print(\"============== New Metric:{} ==============\".format(results))",{"type":18,"tag":53,"props":410,"children":411},{},[412],{"type":24,"value":141},{"type":18,"tag":53,"props":414,"children":415},{},[416],{"type":24,"value":146},{"type":18,"tag":53,"props":418,"children":419},{},[420],{"type":24,"value":151},{"type":18,"tag":53,"props":422,"children":423},{},[424],{"type":24,"value":156},{"type":18,"tag":53,"props":426,"children":427},{},[428],{"type":24,"value":161},{"type":18,"tag":53,"props":430,"children":431},{},[432],{"type":24,"value":166},{"type":18,"tag":53,"props":434,"children":435},{},[436],{"type":24,"value":437},"model_constructed = Model(net, loss, 
opt)",{"type":18,"tag":53,"props":439,"children":440},{},[441],{"type":24,"value":196},{"type":18,"tag":26,"props":443,"children":445},{"id":444},"_2-边训练边验证",[446],{"type":18,"tag":31,"props":447,"children":448},{},[449,457,458],{"type":18,"tag":31,"props":450,"children":451},{},[452],{"type":18,"tag":31,"props":453,"children":454},{},[455],{"type":24,"value":456},"2",{"type":24,"value":43},{"type":18,"tag":31,"props":459,"children":460},{},[461],{"type":18,"tag":31,"props":462,"children":463},{},[464],{"type":24,"value":465},"边训练边验证",{"type":18,"tag":53,"props":467,"children":468},{},[469],{"type":24,"value":470},"在训练的过程中，在验证集上测试模型的效果。目前MindSpore有两种方式：",{"type":18,"tag":53,"props":472,"children":473},{},[474],{"type":24,"value":475},"1、交替调用model.train()和model.eval()，实现边训练边验证。",{"type":18,"tag":53,"props":477,"children":478},{},[479],{"type":24,"value":480},"...def train_and_val(model, dataset_train, dataset_val, steps_per_train,",{"type":18,"tag":53,"props":482,"children":483},{},[484],{"type":24,"value":485},"epoch_max, evaluation_interval):",{"type":18,"tag":53,"props":487,"children":488},{},[489],{"type":24,"value":490},"config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_train,",{"type":18,"tag":53,"props":492,"children":493},{},[494],{"type":24,"value":495},"keep_checkpoint_max=epoch_max)",{"type":18,"tag":53,"props":497,"children":498},{},[499],{"type":24,"value":500},"ckpoint_cb = ModelCheckpoint(prefix=\"train_resnet_cifar10\",",{"type":18,"tag":53,"props":502,"children":503},{},[504],{"type":24,"value":505},"directory=\"./\", config=config_ck)",{"type":18,"tag":53,"props":507,"children":508},{},[509],{"type":24,"value":510},"model.train(evaluation_interval, dataset_train,",{"type":18,"tag":53,"props":512,"children":513},{},[514],{"type":24,"value":515},"callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True)",{"type":18,"tag":53,"props":517,"children":518},{},[519],{"type":24,"value":520},"acc = model.eval(dataset_val, 
dataset_sink_mode=False)",{"type":18,"tag":53,"props":522,"children":523},{},[524],{"type":24,"value":525},"print(\"============== Evaluation:{} ==============\".format(acc))",{"type":18,"tag":53,"props":527,"children":528},{},[529],{"type":24,"value":141},{"type":18,"tag":53,"props":531,"children":532},{},[533],{"type":24,"value":146},{"type":18,"tag":53,"props":535,"children":536},{},[537],{"type":24,"value":538},"ds_train, steps_per_epoch_train = create_dataset(TRAIN_PATH,",{"type":18,"tag":53,"props":540,"children":541},{},[542],{"type":24,"value":543},"do_train=True, batch_size=TRAIN_BATCH_SIZE, repeat_num=1)",{"type":18,"tag":53,"props":545,"children":546},{},[547],{"type":24,"value":548},"ds_val, steps_per_epoch_val = create_dataset(VAL_PATH, do_train=False,",{"type":18,"tag":53,"props":550,"children":551},{},[552],{"type":24,"value":553},"batch_size=VAL_BATCH_SIZE, repeat_num=1)",{"type":18,"tag":53,"props":555,"children":556},{},[557],{"type":24,"value":151},{"type":18,"tag":53,"props":559,"children":560},{},[561],{"type":24,"value":156},{"type":18,"tag":53,"props":563,"children":564},{},[565],{"type":24,"value":161},{"type":18,"tag":53,"props":567,"children":568},{},[569],{"type":24,"value":166},{"type":18,"tag":53,"props":571,"children":572},{},[573],{"type":24,"value":171},{"type":18,"tag":53,"props":575,"children":576},{},[577],{"type":24,"value":176},{"type":18,"tag":53,"props":579,"children":580},{},[581],{"type":24,"value":181},{"type":18,"tag":53,"props":583,"children":584},{},[585],{"type":24,"value":186},{"type":18,"tag":53,"props":587,"children":588},{},[589],{"type":24,"value":590},"net = Model(net, loss, opt, metrics=metrics)",{"type":18,"tag":53,"props":592,"children":593},{},[594],{"type":24,"value":595},"for i in range(int(EPOCH_MAX / EVAL_INTERVAL)):",{"type":18,"tag":53,"props":597,"children":598},{},[599],{"type":24,"value":600},"train_and_val(net, ds_train, ds_val, 
steps_per_epoch_train,",{"type":18,"tag":53,"props":602,"children":603},{},[604],{"type":24,"value":605},"EPOCH_MAX, EVAL_INTERVAL)",{"type":18,"tag":53,"props":607,"children":608},{},[609],{"type":24,"value":610},"2、调用model.train接口时，在callbacks中传入自定义的EvalCallBack实例，实现边训练边验证。",{"type":18,"tag":53,"props":612,"children":613},{},[614],{"type":24,"value":615},"class EvalCallBack(Callback):",{"type":18,"tag":53,"props":617,"children":618},{},[619],{"type":24,"value":620},"def __init__(self, model, eval_dataset, eval_epoch, result_evaluation):",{"type":18,"tag":53,"props":622,"children":623},{},[624],{"type":24,"value":625},"self.model = model",{"type":18,"tag":53,"props":627,"children":628},{},[629],{"type":24,"value":630},"self.eval_dataset = eval_dataset",{"type":18,"tag":53,"props":632,"children":633},{},[634],{"type":24,"value":635},"self.eval_epoch = eval_epoch",{"type":18,"tag":53,"props":637,"children":638},{},[639],{"type":24,"value":640},"self.result_evaluation = result_evaluation",{"type":18,"tag":53,"props":642,"children":643},{},[644],{"type":24,"value":645},"def epoch_end(self, run_context):",{"type":18,"tag":53,"props":647,"children":648},{},[649],{"type":24,"value":650},"cb_param = run_context.original_args()",{"type":18,"tag":53,"props":652,"children":653},{},[654],{"type":24,"value":655},"cur_epoch = cb_param.cur_epoch_num",{"type":18,"tag":53,"props":657,"children":658},{},[659],{"type":24,"value":660},"if cur_epoch % self.eval_epoch == 0:",{"type":18,"tag":53,"props":662,"children":663},{},[664],{"type":24,"value":665},"acc = self.model.eval(self.eval_dataset, 
dataset_sink_mode=False)",{"type":18,"tag":53,"props":667,"children":668},{},[669],{"type":24,"value":670},"self.result_evaluation[\"epoch\"].append(cur_epoch)",{"type":18,"tag":53,"props":672,"children":673},{},[674],{"type":24,"value":675},"self.result_evaluation[\"acc\"].append(acc[\"accuracy\"])",{"type":18,"tag":53,"props":677,"children":678},{},[679],{"type":24,"value":680},"self.result_evaluation[\"loss\"].append(acc[\"loss\"])",{"type":18,"tag":53,"props":682,"children":683},{},[684],{"type":24,"value":685},"print(acc)",{"type":18,"tag":53,"props":687,"children":688},{},[689],{"type":24,"value":141},{"type":18,"tag":53,"props":691,"children":692},{},[693],{"type":24,"value":146},{"type":18,"tag":53,"props":695,"children":696},{},[697],{"type":24,"value":538},{"type":18,"tag":53,"props":699,"children":700},{},[701],{"type":24,"value":702},"do_train=True, batch_size=TRAIN_BATCH_SIZE, repeat_num=REPEAT_SIZE)",{"type":18,"tag":53,"props":704,"children":705},{},[706],{"type":24,"value":548},{"type":18,"tag":53,"props":708,"children":709},{},[710],{"type":24,"value":711},"batch_size=VAL_BATCH_SIZE, repeat_num=REPEAT_SIZE)",{"type":18,"tag":53,"props":713,"children":714},{},[715],{"type":24,"value":151},{"type":18,"tag":53,"props":717,"children":718},{},[719],{"type":24,"value":156},{"type":18,"tag":53,"props":721,"children":722},{},[723],{"type":24,"value":161},{"type":18,"tag":53,"props":725,"children":726},{},[727],{"type":24,"value":166},{"type":18,"tag":53,"props":729,"children":730},{},[731],{"type":24,"value":171},{"type":18,"tag":53,"props":733,"children":734},{},[735],{"type":24,"value":176},{"type":18,"tag":53,"props":737,"children":738},{},[739],{"type":24,"value":181},{"type":18,"tag":53,"props":741,"children":742},{},[743],{"type":24,"value":186},{"type":18,"tag":53,"props":745,"children":746},{},[747],{"type":24,"value":590},{"type":18,"tag":53,"props":749,"children":750},{},[751],{"type":24,"value":752},"result_eval = {\"epoch\": [], \"acc\": [], 
\"loss\": []}",{"type":18,"tag":53,"props":754,"children":755},{},[756],{"type":24,"value":757},"eval_cb = EvalCallBack(net, ds_val, EVAL_PER_EPOCH, result_eval)",{"type":18,"tag":53,"props":759,"children":760},{},[761],{"type":24,"value":762},"net.train(EPOCH_MAX, ds_train,",{"type":18,"tag":53,"props":764,"children":765},{},[766],{"type":24,"value":767},"callbacks=[ckpoint_cb, LossMonitor(), eval_cb],",{"type":18,"tag":53,"props":769,"children":770},{},[771],{"type":24,"value":772},"dataset_sink_mode=True, sink_size=steps_per_epoch_train)",{"type":18,"tag":26,"props":774,"children":776},{"id":775},"_3-样例代码使用说明",[777],{"type":18,"tag":31,"props":778,"children":779},{},[780,788,789],{"type":18,"tag":31,"props":781,"children":782},{},[783],{"type":18,"tag":31,"props":784,"children":785},{},[786],{"type":24,"value":787},"3",{"type":24,"value":43},{"type":18,"tag":31,"props":790,"children":791},{},[792],{"type":18,"tag":31,"props":793,"children":794},{},[795],{"type":24,"value":796},"样例代码使用说明",{"type":18,"tag":53,"props":798,"children":799},{},[800],{"type":24,"value":801},"本文的样例代码是一个Resnet50在Cifar10上训练的分类网络，采用datasets.Cifar10Dataset接口读取二进制版本的CIFAR-10数据集，因此下载CIFAR-10 binary version (suitable for C programs)，并在代码中配置好数据路径。",{"type":18,"tag":53,"props":803,"children":804},{},[805],{"type":24,"value":806},"启动命令：",{"type":18,"tag":53,"props":808,"children":809},{},[810],{"type":24,"value":811},"python xxx.py --data_path=xxx 
--epoch_num=xxx",{"type":18,"tag":53,"props":813,"children":814},{},[815],{"type":24,"value":816},"运行脚本，可以看到网络输出结果：",{"type":18,"tag":53,"props":818,"children":819},{},[820],{"type":18,"tag":821,"props":822,"children":825},"img",{"alt":823,"src":824},"捕获.PNG","https://fileserver.developer.huaweicloud.com/FileServer/getFile/cmtybbs/5e4/e02/8f7/550440a1fe5e4e028f77e5cf18005adc.20220812080015.83232642210810927942825969590649:50530814003507:2400:2077495B473D9E4451ABCEF7AAF48C23D0FA8AC4E009B6ADC403559F1186A409.png",[],{"title":7,"searchDepth":827,"depth":827,"links":828},4,[829,832,834],{"id":28,"depth":830,"text":831},2,"1 模型测试",{"id":444,"depth":830,"text":833},"2 边训练边验证",{"id":775,"depth":830,"text":835},"3 样例代码使用说明","markdown","content:technology-blogs:zh:1691.md","content","technology-blogs/zh/1691.md","technology-blogs/zh/1691","md",1776506114864]