# 基于Java接口实现端侧训练 `Android` `Java` `全流程` `模型加载` `模型训练` `数据准备` `初级` `中级` `高级` [![查看源文件](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.3/docs/lite/docs/source_zh_cn/quick_start/train_lenet_java.md) ## 概述 本教程通过构建并部署Java版本的LeNet网络的训练,演示MindSpore Lite端侧训练Java接口的使用。 首先指导您在本地成功训练LeNet模型,然后讲解示例代码。 ## 准备 ### 环境要求 - 系统环境:Linux x86_64,推荐使用Ubuntu 18.04.02LTS - 软件依赖 - [Git](https://git-scm.com/downloads) >= 2.28.0 - [Maven](https://maven.apache.org/download.cgi) >= 3.3 - [OpenJDK](https://openjdk.java.net/install/) >= 1.8 ### 下载MindSpore并编译端侧训练Java包 首先克隆源码,然后编译MindSpore Lite端侧训练Java包,`Linux`指令如下: ```bash git clone https://gitee.com/mindspore/mindspore.git -b r1.3 cd mindspore bash build.sh -I x86_64 -j8 ``` 更详细的编译说明,请参考[编译MindSpore Lite](https://www.mindspore.cn/lite/docs/zh-CN/r1.3/use/build.html)章节。 本教程使用的示例源码在`mindspore/lite/examples/train_lenet_java`目录。 ### 下载数据集 示例中的`MNIST`数据集由10类28*28的灰度图片组成,训练数据集包含60000张图片,测试数据集包含10000张图片。 > MNIST数据集官网下载地址:,共4个下载链接,分别是训练数据、训练标签、测试数据和测试标签。 下载并解压到本地,解压后的训练和测试集分别存放于`/PATH/MNIST_Data/train`和`/PATH/MNIST_Data/test`路径下。 目录结构如下: ```text MNIST_Data/ ├── test │   ├── t10k-images-idx3-ubyte │   └── t10k-labels-idx1-ubyte └── train ├── train-images-idx3-ubyte └── train-labels-idx1-ubyte ``` ## 部署应用 ### 构建与运行 1. 首先进入示例工程所在目录,运行示例程序,命令如下: ```bash cd /codes/mindspore/mindspore/lite/examples/train_lenet_java ./prepare_and_run.sh -D /PATH/MNIST_Data/ -r ../../../../output/mindspore-lite-${version}-linux-x64.tar.gz ``` > ../resources/model/lenet_tod.ms是示例工程中预置的LeNet训练模型,您也可以参考[训练模型转换](https://www.mindspore.cn/lite/docs/zh-CN/r1.3/use/converter_train.html),自行转换出LeNet模型。 > > /PATH/MNIST_Data/是MNIST数据集所在路径。 示例运行结果如下: ```text MindSpore Lite 1.3.0 ==========Loading Model, Create Train Session============= Model path is ../model/lenet_tod.ms batch_size: 4 virtual batch multiplier: 16 ==========Initing DataSet================ train data cnt: 60000 test data cnt: 10000 ==========Training Model=================== step_500: Loss is 0.05553353 [min=0.010149269] max_accc=0.9543269 step_1000: Loss is 0.15295759 [min=0.0018140086] max_accc=0.96594554 step_1500: Loss is 0.018035552 [min=0.0018140086] max_accc=0.9704527 step_2000: Loss is 0.029250022 [min=0.0010245014] max_accc=0.9765625 step_2500: Loss is 0.11875624 [min=7.5288175E-4] max_accc=0.9765625 step_3000: Loss is 0.046675075 [min=7.5288175E-4] max_accc=0.9765625 step_3500: Loss is 0.034442786 [min=4.3545474E-4] max_accc=0.97686297 ==========Evaluating The Trained Model============ accuracy = 0.9770633 Trained model successfully saved: ./model/lenet_tod_trained.ms ``` ## 示例程序详细说明 ### 示例程序结构 ```text train_lenet_java ├── lib ├── build.sh ├── model │ ├── lenet_export.py │ ├── prepare_model.sh │ └── train_utils.sh ├── pom.xml ├── prepare_and_run.sh ├── resources │   └── model │   └── lenet_tod.ms # LeNet训练模型 ├── src │   └── main │   └── java │   └── com │   └── mindspore │   └── lite │   ├── train_lenet │   │   ├── DataSet.java # MNIST数据集处理 │   │   ├── Main.java # Main函数 │   │   └── NetRunner.java # 整体训练流程 ``` ### 编写端侧推理代码 详细的Java接口使用请参考。 1. 加载MindSpore Lite模型文件,构建会话。 ```java MSConfig msConfig = new MSConfig(); // arg 0: DeviceType:DT_CPU -> 0 // arg 1: ThreadNum -> 2 // arg 2: cpuBindMode:NO_BIND -> 0 // arg 3: enable_fp16 -> false msConfig.init(0, 2, 0, false); session = new LiteSession(); System.out.println("Model path is " + modelPath); session = session.createTrainSession(modelPath, msConfig, false); session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f); ``` 2. 切换为训练模式,循环迭代,训练模型。 ```java session.train(); float min_loss = 1000; float max_acc = 0; for (int i = 0; i < cycles; i++) { for (int b = 0; b < virtualBatch; b++) { fillInputData(ds.getTrainData(), false); session.runGraph(); float loss = getLoss(); if (min_loss > loss) { min_loss = loss; } if ((b == 0) && ((i + 1) % 500 == 0)) { float acc = calculateAccuracy(10); // only test 10 batch size if (max_acc < acc) { max_acc = acc; } System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc); } } } ``` 3. 切换为推理模式,执行推理,评估模型精度。 ```java session.eval(); for (long i = 0; i < tests; i++) { Vector labels = fillInputData(test_set, (maxTests == -1)); if (labels.size() != batchSize) { System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize); System.exit(1); } session.runGraph(); MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses)); if (outputsv == null) { System.err.println("can not find output tensor with size: " + batchSize * numOfClasses); System.exit(1); } float[] scores = outputsv.getFloatData(); for (int b = 0; b < batchSize; b++) { int max_idx = 0; float max_score = scores[(int) (numOfClasses * b)]; for (int c = 0; c < numOfClasses; c++) { if (scores[(int) (numOfClasses * b + c)] > max_score) { max_score = scores[(int) (numOfClasses * b + c)]; max_idx = c; } } if (labels.get(b) == max_idx) { accuracy += 1.0; } } } ``` 推理完成后,如果需要继续训练,需要切换为训练模式。 4. 保存训练模型。 ```java // arg 0: FileName // arg 1: model type MT_TRAIN -> 0 // arg 2: quantization type QT_DEFAULT -> 0 session.export(trainedFilePath, 0, 0) ``` 模型训练完成后,保存到指定路径,后续可以继续加载运行。