【免费赠书】长按扫描上方二维码成为社区会员
填写会员资料完成后会返回一个序列号码,号码为1、10、20、30、40的朋友,可以免费获得《TensorFlow与卷积神经网络从算法入门到项目实战》图书一本,号码为50、60、70、80、90的朋友,可以免费获得《Python网络爬虫从入门到实践》图书一本,奖品图书由机械工业出版社提供。如果您发现序号为以上数字,请务必以您填写的邮箱发送邮寄信息(包括姓名、手机、地址)到工作人员邮箱:pythonpost@163.com,信息整理后获奖图书将在一周左右寄出。赠书样本请在文末查看,未获奖的朋友也可以在文末查看购书。
活动截止日期:8月20日(周二)晚上22:30。
Tensorflow Lite库是Tensorflow官方推荐使用的库,相比Tensorflow Mobile库,Tensorflow Lite库对应的二进制模型文件更小、依赖库更少并且Tensorflow Lite库拥有更好的性能。
一、模型转换
Tensorflow Mobile库中使用的模型文件并不能直接在Tensorflow Lite库中直接使用,因此,使用Tensorflow Lite库也需要将模型进行转换,示例代码如下所示:
注意:Windows系统中Tensorflow库中的TocoConverter存在BUG(官方修复中),因此以下代码请在Linux(或Mac)系统中最新Tensorflow(版本大于等于1.9)库中执行。
01 import tensorflow as tf
02 from nets.mobilenet import MobileNetV2
03 lite = tf.contrib.lite
04 def inference(input_tf,n_classes):
05 net = MobileNetV2(n_classes=n_classes,depth_rate=1.0,is_training=False)
06 output = net.build_graph(input_tf)
07 output = tf.nn.softmax(output)
08 output = tf.nn.top_k(output, k=5, sorted=True)
09 return output.indices,output.values
10 def main():
11 input_tf = tf.placeholder(tf.float32,shape=(1,64,64,3),name='input')
12 classes_tf,scores_tf = inference(input_tf, 3755)
13 classes = tf.identity(classes_tf, name='classes')
14 scores = tf.identity(scores_tf, name='scores')
15 restore_saver = tf.train.Saver()
16
17 with tf.Session() as sess:
18 sess.run(tf.global_variables_initializer())
19 sess.run(tf.local_variables_initializer())
20 restore_saver.restore(sess,'model_1/model-170000')
21
22 converter = lite.TocoConverter.from_session(sess, [input_tf], [classes,scores])
23 tflite_model = converter.convert()
24 open("model/hw_model.tflite", "wb").write(tflite_model)
25
26 main()
以上代码中,第4~9行功能为定义网络结构静态图,第11~20行加载模型参数到当前网络静态图中。第22行代码定义模型转换对象(TocoConverter对象),指定模型的输入节点和输出节点,并在23行中执行模型转换。执行上面代码后,在目录“model”中生成文件hw_model.tflite。
二、模型调用
在使用Tensorflow Lite库之前,需要填写Tensorflow Lite库依赖,并且指定tensorflow lite模型不被压缩。即在文件app/build.gradle文件的android模块和dependencies模块中添加如下加粗部分。
01 android {
02 #其他略
03 aaptOptions {
04 noCompress "tflite"
05 }
06 }
07
08 dependencies {
09 #其他依赖略...
10 compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
11 }
注意:使用Tensorflow Lite时,需要使用AndroidStudio 3.0以上版本。
有了Tensorflow Lite库后,接下来可以直接使用库相关接口函数。主要使用到Tensorflow Lite库中的org.tensorflow.lite.Interpreter类对象,类Interpreter常用的函数如下:
01 //构造函数,其中参数modelFile对应模型文件的File对象。
02 public Interpreter(@NonNull File modelFile) {...}
03
04 //构造函数,其中,
05 //参数modelFile对应模型文件的File对象。
06 //参数options为Options对象,用于配置线程数等相关配置参数。
07 public Interpreter(@NonNull File modelFile, Options options) {...}
08
09 //构造函数,其中参数byteBuffer为存放模型数据的ByteBuffer对象
10 public Interpreter(@NonNull ByteBuffer byteBuffer) {...}
11
12 //构造函数,其中参数byteBuffer为存放模型数据的ByteBuffer对象
13 //参数options为Options对象,用于配置线程数等相关配置参数。
14 public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {... }
15
16 //执行模型,只能是一个输入节点和一个输出节点,其中,
17 //参数input是数组类型,或ByteBuffer类型。如果是数组类型,可以是int、float、long以及
18 //byte等数据类型。
19 //参数output是数组类型,或ByteBuffer类型。如果是数组类型,可以是int、float、long以及
20 //byte等数据类型。
21 public void run(@NonNull Object input, @NonNull Object output) {.. }
22
23 //执行模型,允许多个输入节点和多个输出节点
24 //参数inputs是多个输入数据组成的数组,其中每个输入数据可以是数组类型或ByteBuffer类25 //型。如果是数组类型,可以是int、float、long以及byte等数据类型。
26 //参数outputs为Map类型,其中key为输出节点的索引值,value为数组类型或者是ByteBuffer
27 //类型。
28 public void runForMultipleInputsOutputs( @NonNull Object[] inputs,
29 @NonNull Map<Integer, Object> outputs) {... }
30
31 //对模型中的第idx个输入做resize操作
32 public void resizeInput(int idx, @NonNull int[] dims) {... }
33
34 //获取输入节点的数量
35 public int getInputTensorCount() {...}
36
37 //获取指定名称的输入节点在模型输入中的索引值
38 public int getInputIndex(String opName) {... }
39
40 //返回指定索引的输入节点对象(Tensor对象)
41 public Tensor getInputTensor(int inputIndex) {... }
42
43 //获取输出节点数量
44 public int getOutputTensorCount() {...}
45
46 //获取指定名称的输出节点在模型输出中的索引值
47 public int getOutputIndex(String opName) {....}
48
49 //获取指定索引值处的输出节点对象(即Tensor对象)
50 public Tensor getOutputTensor(int outputIndex) {...}
注意:使用Interpreter的run和runForMultipleInputsOutputs函数时,使用ByteBuffer类型的输入数据执行速度比数组类型更快。
有了Interpreter提供的以上基础函数后,接下来调用Tensorflow Lite模型,使用Interpreter调用模型的基本步骤如下:
(1)加载模型参数,通过Interpreter类构造函数实现。
(2)执行网络模型,并传入输入数据和输出拷贝空间,通过函数run或函数runForMultipleInputsOutputs实现。
根据以上基本步骤,创建文件RunLiteModel.java,用于封装模型调用,示例代码如下所示:
01 package com.huachao.cnr_lite;
02 //import ...
03 public class RunLiteModel {
04 private String inputName;
05 private String[] outputNames;
06 private int inputWH;
07 private int[] inputIntData ;
08 private float[] inputFloatData ;
09
10 //定义Interpreter的配置选项Options对象
11 private final Interpreter.Options tfliteOptions = new Interpreter.Options();
12 //加载Tensorflow Lite模型
13 private MappedByteBuffer tfliteModel;
14 //定义Interpreter对象,用于执行Tensorflow Lite模型
15 protected Interpreter tflite;
16 //定义ByteBuffer对象,用于存放图片数据,并传入Tensorflow Lite模型输入节点中
17 protected ByteBuffer imgData = null;
18 //定义数组对象,用于存放Tensorflow Lite输出结果
19 private int[][] classes = null;
20 private float[][] scores = null;
21
22 public RunLiteModel(String modelName, String inputName, String[] outputNames,
23 int inputWH, AssetManager assetMngr){
24 this.inputName=inputName;
25 this.outputNames=outputNames;
26 this.inputWH=inputWH;
27 this.inputIntData=new int[inputWH*inputWH];
28 this.inputFloatData = new float[inputWH*inputWH*3];
29 //从assets目录加载模型
30 tfliteModel = loadModelFile(assetMngr,modelName);
31 tflite = new Interpreter(tfliteModel, tfliteOptions);
32 //height*width*sizeof(float)*channel
33 imgData = ByteBuffer.allocateDirect( inputWH * inputWH * 4 * 3);
34 imgData.order(ByteOrder.nativeOrder());
35 classes = new int[1][5];
36 scores = new float[1][5];
37 }
38 public Map<String,Object> run(Bitmap bitmap) {
39 convertBitmapToByteBuffer(bitmap);
40 //如果只有一个输入和一个输出,则使用以下函数
41 //tflite.run(imgData, classes);
42 //本例中,有两个输出,因此使用runForMultipleInputsOutputs函数
43 Map<Integer, Object> outputs = new ArrayMap<>(2);
44 outputs.put(0, classes);
45 outputs.put(1, scores);
46 tflite.runForMultipleInputsOutputs(new Object[]{imgData}, outputs);
47 Map<String,Object> results=new HashMap<>();
48 results.put("classes",classes[0]);
49 results.put("scores",scores[0]);
50 return results;
51 }
52 //对图像做Resize
53 public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
54 //与12.2.2节中的放缩函数一致,代码略..
55 }
56
57 //将图片数据写入到ByteBuffer对象中
58 private void convertBitmapToByteBuffer(Bitmap bitmap) {
59 Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
60 if (imgData == null) {
61 return;
62 }
63 imgData.rewind();
64 bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
65 //将图片数据转为浮点数类型
66 int pixel = 0;
67 for (int i = 0; i < inputWH; ++i) {
68 for (int j = 0; j < inputWH; ++j) {
69 final int val = inputIntData[pixel++];
70 imgData.putFloat((val >> 16) & 0xFF) ;
71 imgData.putFloat((val >> 8) & 0xFF);
72 imgData.putFloat(val & 0xFF);
73 }
74 }
75 }
76
77 //从assets目录中读取Tensorflow Lite模型
78 private MappedByteBuffer loadModelFile(AssetManager assetMngr,String modelName) {
79 try {
80 AssetFileDescriptor fileDescriptor = assetMngr.openFd(modelName);
81 FileInputStream inputStream =
82 new FileInputStream(fileDescriptor.getFileDescriptor());
83 FileChannel fileChannel = inputStream.getChannel();
84 long startOffset = fileDescriptor.getStartOffset();
85 long declaredLength = fileDescriptor.getDeclaredLength();
86 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset,
87 declaredLength);
88 }catch (Exception e){
89 e.printStackTrace();
90 }
91 return null;
92 }
93 }
以上代码中,第22~37行为构造函数,指定模型的输入和输出名称、加载模型参数数据并初始化Interpreter对象以及分配输入和输出数据存储空间。第38~51行代码为模型调用接口封装,将传入的Bitmap转为ByteBuffer对象,并执行runForMultipleInputsOutputs函数将输入数据传入模型输入节点,获取模型输出数据并返回。
三、模型测试
与模型测试一样,将汉字字符与索引值映射关系文件char_list.txt存放在“assets”目录中,在MainActivity类中实现加载char_list.txt文件、撤退和清空按钮调用入口、监听手写笔画并将回调的Bitmap对象传入RunLiteModel对象中调用模型以及在界面中显示识别结果,调用Tensorflow Lite模型的MainActivity代码略。分别手写“测”、“试”两个字,识别效果如图所示。
图 使用Tensorflow Lite库对手写“测试”识别效果