Android中使用TensorFlow Lite实现图像分类

android studio 教程 | 2018-09-28 01:12

今日科技快讯

9月26日,对于近期媒体报道的阿里系阻挠滴滴收购ofo一事,网易科技向阿里巴巴方面进行求证时,被官方予以了否认。自8月份始,滴滴将收购ofo的新闻已在市场上传过几轮。到9月中下旬,有媒体报道指出,滴滴日前曾和ofo就收购一事达成协议,但最终因阿里系不同意而导致交易告吹。

作者简介

本篇来自 夜雨飘零 的投稿,分享了关于 Android中如何使用TensorFlow Lite实现图像分类,一起来看看!希望大家喜欢。

夜雨飘零 的博客地址:

TensorFlow Lite是一款专门针对移动设备的深度学习框架,移动设备深度学习框架是部署在手机或者树莓派等小型移动设备上的深度学习框架,可以使用训练好的模型在手机等设备上完成推理任务。这一类框架的出现,可以使得一些推理的任务可以在本地执行,不需要再调用服务器的网络接口,大大减少了预测时间。在前几篇文章中已经介绍了百度的paddle-mobile,小米的mace,还有腾讯的ncnn。这在本章中我们将介绍谷歌的TensorFlow Lite。

TensorFlow Lite的GitHub地址:

转换模型

手机上执行预测,首先需要一个训练好的模型,这个模型不能是TensorFlow原来格式的模型,TensorFlow Lite使用的模型格式是另一种格式的模型。

下面就介绍如何使用这个格式的模型。 获取模型主要有两种方法,第一种是在训练的时候就保存tflite模型,另外一种就是使用其他格式的TensorFlow模型转换成tflite模型。

最方便的就是在训练的时候保存tflite格式的模型,主要是使用到tf.contrib.lite.toco_convert()接口,下面就是一个简单的例子:

import tensorflow as tfimg = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])out = tf.identity(val, name="out")with tf.Session() as sess:  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])  open("converteds_model.tflite", "wb").write(tflite_model)最后获得的converteds_model.tflite文件就可以直接在TensorFlow Lite上使用。

第二种就是把tensorflow保存的其他模型转换成tflite,我们可以在以下的链接下载模型。tensorflow模型地址如下所示:

上面提供的模型同时也包括了tflite模型,我们可以直接拿来使用,但是我们也可以使用其他格式的模型来转换。比如我们下载一个mobilenet_v1_1.0_224.tgz,解压之后获得以下文件:

mobilenet_v1_1.0_224.ckpt.data-00000-of-00001  mobilenet_v1_1.0_224_eval.pbtxt  mobilenet_v1_1.0_224.tflitemobilenet_v1_1.0_224.ckpt.index                mobilenet_v1_1.0_224_frozen.pbmobilenet_v1_1.0_224.ckpt.meta                 mobilenet_v1_1.0_224_info.txt首先要安装Bazel,可以参考:

只需要完成Installing using binary installer这一部分即可。然后克隆TensorFlow的源码:

git clone 接着编译转换工具,这个编译时间可能比较长:

cd tensorflow/bazel build tensorflow/python/tools:freeze_graphbazel build tensorflow/contrib/lite/toco:toco获得到转换工具之后,我们就可以开始转换模型了,以下操作是冻结图。

input_graph对应的是.pb文件;

input_checkpoint对应的是mobilenet_v1_1.0_224.ckpt.data-00000-of-00001,但是在使用的使用是去掉后缀名的。

output_node_names这个可以在mobilenet_v1_1.0_224_info.txt中获取。

不过要注意的是我们下载的模型已经是冻结过来,所以不用再执行这个操作。但如果是其他的模型,要先冻结图,然后再执行之后的操作。

./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \  --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \  --input_binary=true \  --output_graph=/tmp/frozen_mobilenet_v1_224.pb \  --output_node_names=MobilenetV1/Predictions/Reshape_1以下操作就是把已经冻结的图转换成.tflite:

input_file是已经冻结的图;

output_file是转换后输出的路径;

output_arrays这个可以在mobilenet_v1_1.0_224_info.txt中获取;

input_shapes这个是预测数据的shape

./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \  --input_format=TENSORFLOW_GRAPHDEF \  --output_format=TFLITE \  --output_file=/tmp/mobilenet_v1_1.0_224.tflite \  --inference_type=FLOAT \  --input_type=FLOAT \  --input_arrays=input \  --output_arrays=MobilenetV1/Predictions/Reshape_1 \  --input_shapes=1,224,224,3经过上面的步骤就可以获取到mobilenet_v1_1.0_224.tflite模型了,之后我们会在Android项目中使用它。

开发Android项目

有了上面的模型之后,我们就使用Android Studio创建一个Android项目,一路默认就可以了,并不需要C++的支持,因为我们使用到的TensorFlow Lite是Java代码的,开发起来非常方便。

1、创建完成之后,在app目录下的build.gradle配置文件加上以下配置信息: 在dependencies下加上包的引用,第一个是图片加载框架Glide,第二个就是我们这个项目的核心TensorFlow Lite:

implementation 'com.github.bumptech.glide:glide:4.3.1'implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'然后在android下加上以下代码,这个主要是限制不要对tensorflow lite的模型进行压缩,压缩之后就无法加载模型了:

//set no compress models aaptOptions {    noCompress "tflite" }在main目录下创建assets文件夹,这个文件夹主要是存放tflite模型和label名称文件。

以下是主界面的代码MainActivity.java,这个代码比较长,我们来分析这段代码,重要的方法介绍如下:

loadModelFile()方法是把模型文件读取成MappedByteBuffer,之后给Interpreter类初始化模型,这个模型存放在main的assets目录下。

load_model()方法是加载模型,并得到一个对象tflite,之后就是使用这个对象来预测图像,同时可以使用这个对象设置一些参数,比如设置使用的线程数量tflite.setNumThreads(4);

showDialog()方法是显示弹窗,通过这个弹窗的选择不同的模型。

readCacheLabelFromLocalFile()方法是读取文件种分类标签对应的名称,这个文件比较长,可以参考这篇文章获取标签名称,也可以下载笔者的项目,里面有对用的文件。这个文件cacheLabel.txt跟模型一样存放在assets目录下。

predict_image()方法是预测图片并显示结果的,预测的流程是:获取图片的路径,然后使用对图片进行压缩,之后把图片转换成ByteBuffer格式的数据,最后调用tflite.run()方法进行预测。

get_max_result()方法是获取最大概率的标签。

start_camera()方法是启动相机拍照并返回图片的路径,兼容了Android 7.0。 use_photo()方法是打开相册,获取选择的图片的URI。

get_path_from_URI()方法是把图片的URI转换成图片路径。

getScaledMatrix()方法是把图片的Bitmap格式转换成TensorFlow Lite所需的数据格式。

getScaleBitmap()方法是压缩图片,防止内存溢出。

package com.yeyupiaoling.testtflite;import android.app.Activity;import android.content.Context;import android.content.Intent;import android.database.Cursor;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.net.Uri;import android.os.Build;import android.os.Environment;import android.provider.MediaStore;import android.support.v4.content.FileProvider;import android.util.Log;import java.io.File;import java.io.IOException;import java.nio.ByteBuffer;import java.nio.ByteOrder;public class PhotoUtil {    // start camera    public static String start_camera(Activity activity, int requestCode) {        Uri imageUri;        // save image in cache path        File outputImage = new File(Environment.getExternalStorageDirectory().getAbsolutePath()                + "/lite_mobile/", System.currentTimeMillis() + ".jpg");        Log.d("outputImage", outputImage.getAbsolutePath());        try {            if (outputImage.exists()) {                outputImage.delete();            }            File out_path = new File(Environment.getExternalStorageDirectory().getAbsolutePath()                    + "/lite_mobile/");            if (!out_path.exists()) {                out_path.mkdirs();            }            outputImage.createNewFile();        } catch (IOException e) {            e.printStackTrace();        }        if (Build.VERSION.SDK_INT >= 24) {            // compatible with Android 7.0 or over            imageUri = FileProvider.getUriForFile(activity,                    "com.yeyupiaoling.testtflite.fileprovider", outputImage);        } else {            imageUri = Uri.fromFile(outputImage);        }        // set system camera Action        Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);        intent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);        // set save photo path        intent.putExtra(MediaStore.EXTRA_OUTPUT, imageUri);        // set photo quality, min is 0, max is 1        intent.putExtra(MediaStore.EXTRA_VIDEO_QUALITY, 0);        activity.startActivityForResult(intent, requestCode);        // return image absolute path        return outputImage.getAbsolutePath();    }    // get picture in photo    public static void use_photo(Activity activity, int requestCode) {        Intent intent = new Intent(Intent.ACTION_PICK);        intent.setType("image/*");        activity.startActivityForResult(intent, requestCode);    }    // get photo from Uri    public static String get_path_from_URI(Context context, Uri uri) {        String result;        Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);        if (cursor == null) {            result = uri.getPath();        } else {            cursor.moveToFirst();            int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);            result = cursor.getString(idx);            cursor.close();        }        return result;    }    // TensorFlow model,get predict data    public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {        ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);        imgData.order(ByteOrder.nativeOrder());        // get image pixel        int[] pixels = new int[ddims[2] * ddims[3]];        Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);        int pixel = 0;        for (int i = 0; i < ddims[2]; ++i) {            for (int j = 0; j < ddims[3]; ++j) {                final int val = pixels[pixel++];                imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));                imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));                imgData.putFloat((((val & 0xFF) - 128f) / 128f));            }        }        if (bm.isRecycled()) {            bm.recycle();        }        return imgData;    }    // compress picture    public static Bitmap getScaleBitmap(String filePath) {        BitmapFactory.Options opt = new BitmapFactory.Options();        opt.inJustDecodeBounds = true;        BitmapFactory.decodeFile(filePath, opt);        int bmpWidth = opt.outWidth;        int bmpHeight = opt.outHeight;        int maxSize = 500;        // compress picture with inSampleSize        opt.inSampleSize = 1;        while (true) {            if (bmpWidth / opt.inSampleSize < maxSize || bmpHeight / opt.inSampleSize < maxSize) {                break;            }            opt.inSampleSize *= 2;        }        opt.inJustDecodeBounds = false;        return BitmapFactory.decodeFile(filePath, opt);    }}AndroidManifest.xml下加上申请的权限,用到了相机和读取外部存储的内存:

<uses-permission android:name="android.permission.CAMERA"/>    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>然后还要在application下加上以下的配置信息,这个主要是为了兼容Android 7.0的相机:

<!-- FileProvider配置访问路径,适配7.0及其以上 -->        <provider            android:name="android.support.v4.content.FileProvider"            android:authorities="com.yeyupiaoling.testtflite.fileprovider"            android:exported="false"            android:grantUriPermissions="true">            <meta-data                android:name="android.support.FILE_PROVIDER_PATHS"                android:resource="@xml/file_paths"/>        </provider>之后在res创建一个xml目录,然后创建一个file_paths.xml文件,在这个文件中加上以下代码,这个是我们拍照之后图片存放的位置:

<?xml version="1.0" encoding="utf-8"?><resources>    <external-path        name="images"        path="lite_mobile/" /></resources>主界面布局代码activity_main.xml:

<?xml version="1.0" encoding="utf-8"?><RelativeLayout xmlns:android=""    xmlns:app="http://schemas.android.com/apk/res-auto"    xmlns:tools="http://schemas.android.com/tools"    android:layout_width="match_parent"    android:layout_height="match_parent"    tools:context=".MainActivity">    <LinearLayout        android:id="@+id/btn1_ll"        android:layout_width="match_parent"        android:layout_height="wrap_content"        android:layout_alignParentBottom="true"        android:orientation="horizontal">        <Button            android:id="@+id/use_photo"            android:layout_width="0dp"            android:layout_height="wrap_content"            android:layout_weight="1"            android:text="相册" />        <Button            android:id="@+id/start_camera"            android:layout_width="0dp"            android:layout_height="wrap_content"            android:layout_weight="1"            android:text="拍照" />    </LinearLayout>    <LinearLayout        android:id="@+id/btn2_ll"        android:layout_width="match_parent"        android:layout_height="wrap_content"        android:layout_above="@id/btn1_ll"        android:orientation="horizontal">        <Button            android:id="@+id/load_model"            android:layout_width="0dp"            android:layout_height="wrap_content"            android:layout_weight="1"            android:text="加载模型" />    </LinearLayout>    <TextView        android:id="@+id/result_text"        android:layout_width="match_parent"        android:layout_height="150dp"        android:layout_above="@id/btn2_ll"        android:hint="预测结果会在这里显示"        android:inputType="textMultiLine"        android:textSize="16sp"        tools:ignore="TextViewEdits" />    <ImageView        android:id="@+id/show_image"        android:layout_width="match_parent"        android:layout_height="match_parent"        android:layout_above="@id/result_text"        android:layout_alignParentTop="true" /></RelativeLayout>以下就是效果图片:

上面已经提高了全部代码,这里为了方便读者调试,项目地址如下所示:

使用Android Studio打开即可。