基于Java接口实现端侧训练

Android Java 全流程 模型加载 模型训练 数据准备 初级 中级 高级

概述

本教程通过构建并部署Java版本的LeNet网络的训练,演示MindSpore Lite端侧训练Java接口的使用。 首先指导您在本地成功训练LeNet模型,然后讲解示例代码。

准备

环境要求

  • 系统环境:Linux x86_64,推荐使用Ubuntu 18.04.02LTS

  • 软件依赖

下载MindSpore并编译端侧训练Java包

首先克隆源码,然后编译MindSpore Lite端侧训练Java包,Linux指令如下:

git clone https://gitee.com/mindspore/mindspore.git
cd mindspore
bash build.sh -A java -ecpu -Ton -j8

更详细的编译说明,请参考编译MindSpore Lite章节。 本教程使用的示例源码在mindspore/lite/examples/train_lenet_java目录。

下载数据集

示例中的MNIST数据集由10类28*28的灰度图片组成,训练数据集包含60000张图片,测试数据集包含10000张图片。

MNIST数据集官网下载地址:http://yann.lecun.com/exdb/mnist/,共4个下载链接,分别是训练数据、训练标签、测试数据和测试标签。

下载并解压到本地,解压后的训练和测试集分别存放于/PATH/MNIST_Data/train/PATH/MNIST_Data/test路径下。

目录结构如下:

MNIST_Data/
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte

部署应用

运行依赖

在准备阶段,我们已经成功编译出MindSpore Lite端侧训练Java包。假设您的MindSpore源码路径为/codes/mindspore,对应编译出的Java包在/codes/mindspore/output目录。解压Java包并拷贝相关文件到示例程序目录。命令如下:

cd /codes/mindspore/output
tar xzf mindspore-lite-${version}-train-linux-x64-jar.tar.gz
mkdir ../mindspore/lite/examples/train_lenet_java/lib
cp mindspore-lite-${version}-train-linux-x64-jar/jar/* ../mindspore/lite/examples/train_lenet_java/lib/

构建与运行

  1. 首先进入示例工程所在目录,使用maven构建本示例。命令如下:

    cd /codes/mindspore/mindspore/lite/examples/train_lenet_java
    mvn package
    
  2. 运行示例程序,命令如下:

    cd /codes/mindspore/mindspore/lite/examples/train_lenet_java/target
    java -Djava.library.path=../lib/ -classpath .:./train_lenet_java.jar:../lib/mindspore-lite-java.jar com.mindspore.lite.train_lenet.Main ../resources/model/lenet_tod.ms /PATH/MNIST_Data/
    

    ../resources/model/lenet_tod.ms是示例工程中预置的LeNet训练模型,您也可以参考训练模型转换,自行转换出LeNet模型。

    /PATH/MNIST_Data/是MNIST数据集所在路径。

    示例运行结果如下:

    MindSpore Lite 1.2.0
    ==========Loading Model, Create Train Session=============
    batch_size: 32
    ==========Initing DataSet================
    train data cnt: 60000
    test data cnt: 10000
    ==========Training Model===================
    step_500: Loss is 0.05553353 [min=0.010149269] max_accc=0.9543269
    step_1000: Loss is 0.15295759 [min=0.0018140086] max_accc=0.96594554
    step_1500: Loss is 0.018035552 [min=0.0018140086] max_accc=0.9704527
    step_2000: Loss is 0.029250022 [min=0.0010245014] max_accc=0.9765625
    step_2500: Loss is 0.11875624 [min=7.5288175E-4] max_accc=0.9765625
    step_3000: Loss is 0.046675075 [min=7.5288175E-4] max_accc=0.9765625
    step_3500: Loss is 0.034442786 [min=4.3545474E-4] max_accc=0.97686297
    ==========Evaluating The Trained Model============
    accuracy = 0.9770633
    Trained model successfully saved: ../resources/model/lenet_tod_trained.ms
    

示例程序详细说明

示例程序结构

train_lenet_java
├── lib
├── pom.xml
├── resources
│   └── model
│       └── lenet_tod.ms   # LeNet训练模型
├── src
│   └── main
│       └── java
│           └── com
│               └── mindspore
│                   └── lite
│                       ├── train_lenet
│                       │   ├── DataSet.java      # MNIST数据集处理
│                       │   ├── Main.java         # Main函数
│                       │   └── NetRunner.java    # 整体训练流程

编写端侧推理代码

详细的Java接口使用请参考https://www.mindspore.cn/doc/api_java/zh-CN/master/index.html

  1. 加载MindSpore Lite模型文件,构建会话。

    MSConfig msConfig = new MSConfig();
    // arg 0: DeviceType:DT_CPU -> 0
    // arg 1: ThreadNum -> 2
    // arg 2: cpuBindMode:NO_BIND ->  0
    // arg 3: enable_fp16 -> false
    msConfig.init(0, 2, 0, false);
    session = new TrainSession();
    session.init(modelPath, msConfig);
    
  2. 切换为训练模式,循环迭代,训练模型。

    session.train();
    float min_loss = 1000;
    float max_acc = 0;
    for (int i = 0; i < cycles; i++) {
        fillInputData(ds.getTrainData(), false);
        session.runGraph();
        float loss = getLoss();
        if (min_loss > loss) {
            min_loss = loss;
        }
        if ((i + 1) % 500 == 0) {
            float acc = calculateAccuracy(10); // only test 10 batch size
            if (max_acc < acc) {
                max_acc = acc;
            }
            System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
        }
    }
    
  3. 切换为推理模式,执行推理,评估模型精度。

    session.eval();
    for (long i = 0; i < tests; i++) {
        Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
        if (labels.size() != batchSize) {
            System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
            System.exit(1);
        }
        session.runGraph();
        MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
        if (outputsv == null) {
            System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
            System.exit(1);
        }
        float[] scores = outputsv.getFloatData();
        for (int b = 0; b < batchSize; b++) {
            int max_idx = 0;
            float max_score = scores[(int) (numOfClasses * b)];
            for (int c = 0; c < numOfClasses; c++) {
                if (scores[(int) (numOfClasses * b + c)] > max_score) {
                    max_score = scores[(int) (numOfClasses * b + c)];
                    max_idx = c;
                }
            }
            if (labels.get(b) == max_idx) {
                accuracy += 1.0;
            }
        }
    }
    

    推理完成后,如果需要继续训练,需要切换为训练模式。

  4. 保存训练模型。

    session.saveToFile(trainedFilePath)
    

    模型训练完成后,保存到指定路径,后续可以继续加载运行。