【免费赠书】如何使用 Tensorflow Lite 库


【免费赠书】长按扫描上方二维码成为社区会员


填写会员资料完成后会返回一个序列号码,号码为1、10、20、30、40的朋友,可以免费获得《TensorFlow与卷积神经网络从算法入门到项目实战》图书一本,号码为50、60、70、80、90的朋友,可以免费获得《Python网络爬虫从入门到实践》图书一本,奖品图书由机械工业出版社提供。如果您发现序号为以上数字,请务必以您填写的邮箱发送邮寄信息(包括姓名、手机、地址)到工作人员邮箱:pythonpost@163.com,信息整理后获奖图书将在一周左右寄出。赠书样本请在文末查看,未获奖的朋友也可以在文末查看购书。


活动截止日期:8月20日(周二)晚上22:30。


Tensorflow Lite库是Tensorflow官方推荐使用的库,相比Tensorflow Mobile库,Tensorflow Lite库对应的二进制模型文件更小、依赖库更少并且Tensorflow Lite库拥有更好的性能。

一、模型转换

Tensorflow Mobile库中使用的模型文件并不能直接在Tensorflow Lite库中直接使用,因此,使用Tensorflow Lite库也需要将模型进行转换,示例代码如下所示:

注意:Windows系统中Tensorflow库中的TocoConverter存在BUG(官方修复中),因此以下代码请在Linux(或Mac)系统中最新Tensorflow(版本大于等于1.9)库中执行。

01  import tensorflow as tf
02  from nets.mobilenet import MobileNetV2  
03  lite = tf.contrib.lite  
04  def inference(input_tf,n_classes):  
05      net = MobileNetV2(n_classes=n_classes,depth_rate=1.0,is_training=False
06      output =  net.build_graph(input_tf)  
07      output = tf.nn.softmax(output) 
08      output = tf.nn.top_k(output, k=5, sorted=True
09      return output.indices,output.values
10  def main(): 
11      input_tf = tf.placeholder(tf.float32,shape=(1,64,64,3),name='input')
12      classes_tf,scores_tf = inference(input_tf, 3755)  
13      classes = tf.identity(classes_tf, name='classes')
14      scores = tf.identity(scores_tf, name='scores')
15      restore_saver = tf.train.Saver() 
16       
17      with tf.Session() as sess:
18          sess.run(tf.global_variables_initializer())
19          sess.run(tf.local_variables_initializer())  
20          restore_saver.restore(sess,'model_1/model-170000')
21         
22          converter = lite.TocoConverter.from_session(sess, [input_tf], [classes,scores])
23          tflite_model = converter.convert()
24          open("model/hw_model.tflite""wb").write(tflite_model)
25
26  main()

以上代码中,第4~9行功能为定义网络结构静态图,第11~20行加载模型参数到当前网络静态图中。第22行代码定义模型转换对象(TocoConverter对象),指定模型的输入节点和输出节点,并在23行中执行模型转换。执行上面代码后,在目录“model”中生成文件hw_model.tflite。

二、模型调用

在使用Tensorflow Lite库之前,需要填写Tensorflow Lite库依赖,并且指定tensorflow lite模型不被压缩。即在文件app/build.gradle文件的android模块和dependencies模块中添加如下加粗部分。

01  android {
02      #其他略
03      aaptOptions {
04          noCompress "tflite"
05      }
06  }
07
08  dependencies {
09      #其他依赖略...
10      compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
11  }

注意:使用Tensorflow Lite时,需要使用AndroidStudio 3.0以上版本。

有了Tensorflow Lite库后,接下来可以直接使用库相关接口函数。主要使用到Tensorflow Lite库中的org.tensorflow.lite.Interpreter类对象,类Interpreter常用的函数如下:

01  //构造函数,其中参数modelFile对应模型文件的File对象。
02  public Interpreter(@NonNull File modelFile) {...}
03
04  //构造函数,其中,
05  //参数modelFile对应模型文件的File对象。
06  //参数options为Options对象,用于配置线程数等相关配置参数。
07  public Interpreter(@NonNull File modelFile, Options options) {...}
08 
09  //构造函数,其中参数byteBuffer为存放模型数据的ByteBuffer对象
10  public Interpreter(@NonNull ByteBuffer byteBuffer) {...}
11
12  //构造函数,其中参数byteBuffer为存放模型数据的ByteBuffer对象
13  //参数options为Options对象,用于配置线程数等相关配置参数。
14  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {... }
15
16  //执行模型,只能是一个输入节点和一个输出节点,其中,
17  //参数input是数组类型,或ByteBuffer类型。如果是数组类型,可以是int、float、long以及
18  //byte等数据类型。
19  //参数output是数组类型,或ByteBuffer类型。如果是数组类型,可以是int、float、long以及
20  //byte等数据类型。
21  public void run(@NonNull Object input, @NonNull Object output) {.. }
22
23  //执行模型,允许多个输入节点和多个输出节点
24  //参数inputs是多个输入数据组成的数组,其中每个输入数据可以是数组类型或ByteBuffer类25  //型。如果是数组类型,可以是int、float、long以及byte等数据类型。
26  //参数outputs为Map类型,其中key为输出节点的索引值,value为数组类型或者是ByteBuffer
27  //类型。
28  public void runForMultipleInputsOutputs( @NonNull Object[] inputs,
29                 @NonNull Map<Integer, Object> outputs)
 
{... }
30
31  //对模型中的第idx个输入做resize操作
32  public void resizeInput(int idx, @NonNull int[] dims) {... }
33
34  //获取输入节点的数量
35  public int getInputTensorCount() {...}
36
37  //获取指定名称的输入节点在模型输入中的索引值
38  public int getInputIndex(String opName) {... }
39
40  //返回指定索引的输入节点对象(Tensor对象)
41  public Tensor getInputTensor(int inputIndex) {... }
42
43  //获取输出节点数量
44  public int getOutputTensorCount() {...}
45
46  //获取指定名称的输出节点在模型输出中的索引值
47  public int getOutputIndex(String opName) {....}
48
49  //获取指定索引值处的输出节点对象(即Tensor对象)
50  public Tensor getOutputTensor(int outputIndex) {...}

注意:使用Interpreter的run和runForMultipleInputsOutputs函数时,使用ByteBuffer类型的输入数据执行速度比数组类型更快。

有了Interpreter提供的以上基础函数后,接下来调用Tensorflow Lite模型,使用Interpreter调用模型的基本步骤如下:

(1)加载模型参数,通过Interpreter类构造函数实现。
(2)执行网络模型,并传入输入数据和输出拷贝空间,通过函数run或函数runForMultipleInputsOutputs实现。

根据以上基本步骤,创建文件RunLiteModel.java,用于封装模型调用,示例代码如下所示:

01  package com.huachao.cnr_lite;
02  //import ...
03  public class RunLiteModel {
04      private String inputName;
05      private String[] outputNames;
06      private int inputWH;
07      private int[] inputIntData ;
08      private float[] inputFloatData ;
09
10      //定义Interpreter的配置选项Options对象
11      private final Interpreter.Options tfliteOptions = new Interpreter.Options();
12      //加载Tensorflow Lite模型
13      private MappedByteBuffer tfliteModel;
14      //定义Interpreter对象,用于执行Tensorflow Lite模型
15      protected Interpreter tflite;
16      //定义ByteBuffer对象,用于存放图片数据,并传入Tensorflow Lite模型输入节点中
17      protected ByteBuffer imgData = null;
18      //定义数组对象,用于存放Tensorflow Lite输出结果 
19      private int[][] classes = null;
20      private float[][] scores = null;
21
22      public RunLiteModel(String modelName, String inputName, String[] outputNames,
23                                            int inputWH, AssetManager assetMngr)
{
24          this.inputName=inputName;
25          this.outputNames=outputNames;
26          this.inputWH=inputWH;
27          this.inputIntData=new int[inputWH*inputWH];
28          this.inputFloatData = new float[inputWH*inputWH*3];
29          //从assets目录加载模型
30          tfliteModel = loadModelFile(assetMngr,modelName);
31          tflite = new Interpreter(tfliteModel, tfliteOptions);
32          //height*width*sizeof(float)*channel
33          imgData = ByteBuffer.allocateDirect( inputWH * inputWH * 4 * 3);
34          imgData.order(ByteOrder.nativeOrder());
35          classes = new int[1][5];
36          scores = new float[1][5];
37      }
38      public Map<String,Object> run(Bitmap bitmap) {
39          convertBitmapToByteBuffer(bitmap);
40          //如果只有一个输入和一个输出,则使用以下函数
41          //tflite.run(imgData, classes);
42          //本例中,有两个输出,因此使用runForMultipleInputsOutputs函数
43          Map<Integer, Object> outputs = new ArrayMap<>(2);
44          outputs.put(0, classes);
45          outputs.put(1, scores);
46          tflite.runForMultipleInputsOutputs(new Object[]{imgData}, outputs);
47          Map<String,Object> results=new HashMap<>();
48          results.put("classes",classes[0]);
49          results.put("scores",scores[0]);
50          return results;
51      }
52      //对图像做Resize
53      public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
54          //与12.2.2节中的放缩函数一致,代码略..
55      }
56
57      //将图片数据写入到ByteBuffer对象中
58      private void convertBitmapToByteBuffer(Bitmap bitmap) {
59          Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
60          if (imgData == null) {
61              return;
62          }
63          imgData.rewind();
64          bm.getPixels(inputIntData, 0, bm.getWidth(), 00, bm.getWidth(), bm.getHeight());
65          //将图片数据转为浮点数类型
66          int pixel = 0
67          for (int i = 0; i < inputWH; ++i) {
68              for (int j = 0; j < inputWH; ++j) {
69                  final int val = inputIntData[pixel++];
70                  imgData.putFloat((val >> 16) & 0xFF) ;
71                  imgData.putFloat((val >> 8) & 0xFF);
72                  imgData.putFloat(val & 0xFF);
73              }
74          } 
75      }
76
77      //从assets目录中读取Tensorflow Lite模型
78      private MappedByteBuffer loadModelFile(AssetManager assetMngr,String modelName) {
79          try {
80              AssetFileDescriptor fileDescriptor = assetMngr.openFd(modelName);
81              FileInputStream inputStream = 
82                                   new FileInputStream(fileDescriptor.getFileDescriptor());
83              FileChannel fileChannel = inputStream.getChannel();
84              long startOffset = fileDescriptor.getStartOffset();
85              long declaredLength = fileDescriptor.getDeclaredLength();
86              return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset,
87                                                                      declaredLength);
88          }catch (Exception e){
89              e.printStackTrace();
90          }
91          return null;
92      }
93  }

以上代码中,第22~37行为构造函数,指定模型的输入和输出名称、加载模型参数数据并初始化Interpreter对象以及分配输入和输出数据存储空间。第38~51行代码为模型调用接口封装,将传入的Bitmap转为ByteBuffer对象,并执行runForMultipleInputsOutputs函数将输入数据传入模型输入节点,获取模型输出数据并返回。

三、模型测试

与模型测试一样,将汉字字符与索引值映射关系文件char_list.txt存放在“assets”目录中,在MainActivity类中实现加载char_list.txt文件、撤退和清空按钮调用入口、监听手写笔画并将回调的Bitmap对象传入RunLiteModel对象中调用模型以及在界面中显示识别结果,调用Tensorflow Lite模型的MainActivity代码略。分别手写“测”、“试”两个字,识别效果如图所示。

图 使用Tensorflow Lite库对手写“测试”识别效果