在 Android上使用 TensorFlow Lite

本文以 OjectDetection 例子为子(市面上一个很火的360智能跟拍云台)展开说明,TensorFlow Lite可以与 Android 8.1 中发布的神经网络 API 完美配合,即便在没有硬件加速时也能调用 CPU 处理,确保模型在不同设备上的运行。

整个工程大致的过程就是从控件 textureView 中以指定的长宽读取一个 Bitma p出来(也就是摄像头的实时画面),然后交给 classifier 的 classifyFrame 进行处理,返回一个结果,这个结果就是物体检测的结果,然后显示在手机屏幕上。

一、环境的搭建

我们可以使用 Android Studio 创建一个 Android 项目,一路默认就可以了,并不需要 C++ 的支持,因为是拿人家训练好的模型直接来用,不用去训练模型,即用到的 TensorFlow Lite 是 Java 代码的,开发起来非常方便。但需要特别的功能,就需要使用 TensorFlow 去训练模型了。

1.1 依赖

创建完成之后,在 app 目录下的 build.gradle 配置文件加上以下配置信息,如在 dependencies 下加上包的引用(每次运行都下载依赖):

    //依赖库
    implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }

对于 Android 有一个地方需要注意,必须在 app 模块的 build.gradle 中添加如下的语句,否则无法加载模型。

    //set no compress models
    aaptOptions {
        noCompress "tflite"
    }

1.2 模型文件配置

在 main 目录下创建 assets 文件夹,这个文件夹主要是存放 tflite 模型和 label 名称文件。



TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。

二、原始数据的获取

手机端的深度学习输入参数有视觉和听觉,即图像和声音,对于图像而言, Camera 是图像采集的唯一工具。因此需要了解 Camera2 的几个比较重要的类:

  • CameraManager: 管理手机上的所有摄像头设备,它的作用主要是获取摄像头列表和打开指定的摄像头;
  • CameraDevice: 具体的摄像头设备,它有一系列参数(预览尺寸、拍照尺寸等),可以通过 CameraManager 的 getCameraCharacteristics() 方法获取。它的作用主要是创建 CameraCaptureSession 和 CaptureRequest;
  • CameraCaptureSession: 相机捕获会话,用于处理拍照和预览的工作(很重要);
  • CaptureRequest: 捕获请求,定义输出缓冲区以及显示界面(TextureView 或 SurfaceView)等。

数据获取的过程:通过Camera 获取图片,然后使用对图片进行压缩,之后把图片转换成 ByteBuffer 格式的数据。

2.1 定义 AutoFitTextureView 作为预览界面

在布局文件中加入 AutoFitTextureView 控件,然后实现其监听事件

textureView = (AutoFitTextureView) view.findViewById(R.id.texture);

然后我们可以在OnResume()方法中设置监听 SurefaceTexture 的事件

textureView.setSurfaceTextureListener(surfaceTextureListener);

当SurefaceTexture准备好后会回调SurfaceTextureListener 的onSurfaceTextureAvailable()方法

TextureView.SurfaceTextureListener textureListener = new TextureView.SurfaceTextureListener() {
    @Override
    public void onSurfaceTextureAvailable(SurfaceTexture surface, int width, int height) {
        //当SurefaceTexture可用的时候,设置相机参数并打开相机
        setUpCameraOutputs(width, height);
        openCamera();
    }
};

2.2 设置相机参数

为了更好地预览,我们根据TextureView的尺寸设置预览尺寸,Camera2中使用CameraManager来管理摄像头

private void setUpCameraOutputs(int width, int height) {
    final Activity activity = getActivity();
    final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
    try {
      final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
      //获取StreamConfigurationMap,它是管理摄像头支持的所有输出格式和尺寸
      StreamConfigurationMap map = characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
      //根据TextureView的尺寸设置预览尺寸
      mPreviewSize = getOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height);
    } catch (CameraAccessException e) {
        e.printStackTrace();
    }
}

2.3 开启相机

Camera2 中打开相机也需要通过 CameraManager 类操作。

private void openCamera() {
    final Activity activity = getActivity();
    final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
    try {
      if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
        throw new RuntimeException("Time out waiting to lock camera opening.");
      }
      manager.openCamera(cameraId, stateCallback, backgroundHandler);
    } catch (final CameraAccessException e) {
      LOGGER.e(e, "Exception!");
    } catch (final InterruptedException e) {
      throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
    }
  }

实现StateCallback 接口,当相机打开后会回调onOpened方法,在这个方法里面开启预览

 private final CameraDevice.StateCallback stateCallback =
      new CameraDevice.StateCallback() {
        @Override
        public void onOpened(final CameraDevice cd) {
          // This method is called when the camera is opened.  We start camera preview here.
          cameraOpenCloseLock.release();
          cameraDevice = cd;
          //开启预览
          createCameraPreviewSession();
        }

        @Override
        public void onDisconnected(final CameraDevice cd) {
          cameraOpenCloseLock.release();
          cd.close();
          cameraDevice = null;
        }
        ......
      };

2.4 开启相机预览

我们使用 TextureView 显示相机预览数据,Camera2 的预览和拍照数据都是使用 CameraCaptureSession 会话来请求的。

private void createCameraPreviewSession() {
    try {
      final SurfaceTexture texture = textureView.getSurfaceTexture();
      assert texture != null;

      //设置TextureView的缓冲区大小
      texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());

      //获取Surface显示预览数据
      final Surface surface = new Surface(texture);

      //创建CaptureRequestBuilder,TEMPLATE_PREVIEW比表示预览请求
      previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
      previewRequestBuilder.addTarget(surface);

      // 使用ImageReader间接实现
      previewReader =
          ImageReader.newInstance(
              previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);

      previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
      previewRequestBuilder.addTarget(previewReader.getSurface());

          //创建相机捕获会话,第一个参数是捕获数据的输出Surface列表,第二个参数是CameraCaptureSession的状态回调接口,当它创建好后会回调onConfigured方法,第三个参数用来确定Callback在哪个线程执行,为null的话就在当前线程执行
      cameraDevice.createCaptureSession(
          Arrays.asList(surface, previewReader.getSurface()),
          new CameraCaptureSession.StateCallback() {

            @Override
            public void onConfigured(final CameraCaptureSession cameraCaptureSession) {
              // The camera is already closed
              if (null == cameraDevice) {
                return;
              }

              //创建捕获请求
              captureSession = cameraCaptureSession;
              try {
                // Auto focus should be continuous for camera preview.
                previewRequestBuilder.set(
                    CaptureRequest.CONTROL_AF_MODE,
                    CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
                // Flash is automatically enabled when necessary.
                previewRequestBuilder.set(
                    CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH);

                // Finally, we start displaying the camera preview.
                previewRequest = previewRequestBuilder.build();
                //设置反复捕获数据的请求,这样预览界面就会一直有数据显示
                captureSession.setRepeatingRequest(
                    previewRequest, captureCallback, backgroundHandler);
              } catch (final CameraAccessException e) {
                LOGGER.e(e, "Exception!");
              }
            }

            @Override
            public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) {
              showToast("Failed");
            }
          },
          null);
    } catch (final CameraAccessException e) {
      LOGGER.e(e, "Exception!");
    }
  }

2.5 拍照

Camera2 拍照也是通过 ImageReader 来实现的。

首先先做些准备工作,设置拍照参数,如方向、尺寸等

  /** Conversion from screen rotation to JPEG orientation. */
  private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
  static {
    ORIENTATIONS.append(Surface.ROTATION_0, 90);
    ORIENTATIONS.append(Surface.ROTATION_90, 0);
    ORIENTATIONS.append(Surface.ROTATION_180, 270);
    ORIENTATIONS.append(Surface.ROTATION_270, 180);
  }

 /** Callback for Camera2 API */
  @Override
  public void onImageAvailable(final ImageReader reader) {
    // We need wait until we have some size from onPreviewSizeChosen
    if (previewWidth == 0 || previewHeight == 0) {
      return;
    }
    if (rgbBytes == null) {
      rgbBytes = new int[previewWidth * previewHeight];
    }
    try {
      final Image image = reader.acquireLatestImage();

      if (image == null) {
        return;
      }

      if (isProcessingFrame) {
        image.close();
        return;
      }
      isProcessingFrame = true;

      final Plane[] planes = image.getPlanes();
      fillBytes(planes, yuvBytes);
      yRowStride = planes[0].getRowStride();
      final int uvRowStride = planes[1].getRowStride();
      final int uvPixelStride = planes[1].getPixelStride();

      imageConverter =
          new Runnable() {
            @Override
            public void run() {
              ImageUtils.convertYUV420ToARGB8888(
                  yuvBytes[0],
                  yuvBytes[1],
                  yuvBytes[2],
                  previewWidth,
                  previewHeight,
                  yRowStride,
                  uvRowStride,
                  uvPixelStride,
                  rgbBytes);
            }
          };

      postInferenceCallback =
          new Runnable() {
            @Override
            public void run() {
              image.close();
              isProcessingFrame = false;
            }
          };
      processImage();
    } catch (final Exception e) {
      LOGGER.e(e, "Exception!");
      Trace.endSection();
      return;
    }
  }

三、TensorFlow Lite 处理

3.1 加载训练模型

loadModelFile()方法是把模型文件读取成MappedByteBuffer,之后给Interpreter类初始化模型

// load infer model
    private void loadModel(String model) {
        try {
            tflite = new Interpreter(loadModelFile(model));
            Log.d(TAG, model + " model load success");
            //tflite.setNumThreads(4);
        } catch (IOException e) {
            Log.d(TAG, model + " model load fail");
            e.printStackTrace();
        }
    }


    /**
     * Memory-map the model file in Assets.
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

得到一个对象tflite,之后就是使用这个对象来预测图像,同时可以使用这个对象设置一些参数。

我们先分析一下再 assets 目录下面怎么加载的?说白了就是新建一个 Interpreter 对象,就是加载模型。上面的方法都过时了,我们可以找到 Interpreter类,里面你会看到如下的方法

//第一个参数传tflite文件,第二个参数传一个Interpreter静态内部类对象
public Interpreter(@NonNull File modelFile, Interpreter.Options options) {
        this.wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}

//所以,我们自己项目里面加载模型,用如下方式即可
//file:///android_asset/labelmap.txt,  detect.tflite
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
tflite = new Interpreter(new File(""), options);

3.2 读取文件种分类标签对应的名称

读取文件种分类标签对应的名称,这个文件 labelmap.txt 跟模型一样存放在 assets 目录下,这个文件比较长,里面有对用的文件。

    private List<String> resultLabel = new ArrayList<>();

    try {
            AssetManager assetManager = getApplicationContext().getAssets();
            BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("labelmap.txt")));
            String readLine = null;
            while ((readLine = reader.readLine()) != null) {
                resultLabel.add(readLine);
            }
            reader.close();
        } catch (Exception e) {
            Log.e("labelCache", "error " + e);
        }

3.3 进行检测

执行run方法

tflite.run(in, out);

    Object[] inputArray = {imgData};
    Map<Integer, Object> outputMap = new HashMap<>();
    outputMap.put(0, outputLocations);
    outputMap.put(1, outputClasses);
    outputMap.put(2, outputScores);
    outputMap.put(3, numDetections);
    Trace.endSection();

    // Run the inference call.
    tfLite.runForMultipleInputsOutputs(inputArray, outputMap);

显示检测结果

    // Show the best detections.
    // after scaling them back to the input size.

    // You need to use the number of detections from the output and not the NUM_DETECTONS variable declared on top
      // because on some models, they don't always output the same total number of detections
      // For example, your model's NUM_DETECTIONS = 20, but sometimes it only outputs 16 predictions
      // If you don't use the output's numDetections, you'll get nonsensical data
    int numDetectionsOutput = Math.min(NUM_DETECTIONS, (int) numDetections[0]); // cast from float to integer, use min for safety

    final ArrayList<Recognition> recognitions = new ArrayList<>(numDetectionsOutput);
    for (int i = 0; i < numDetectionsOutput; ++i) {
      final RectF detection =
          new RectF(
              outputLocations[0][i][1] * inputSize,
              outputLocations[0][i][0] * inputSize,
              outputLocations[0][i][3] * inputSize,
              outputLocations[0][i][2] * inputSize);
      // SSD Mobilenet V1 Model assumes class 0 is background class
      // in label file and class labels start from 1 to number_of_classes+1,
      // while outputClasses correspond to class index from 0 to number_of_classes
      int labelOffset = 1;
      recognitions.add(
          new Recognition(
              "" + i,
              labels.get((int) outputClasses[0][i] + labelOffset),
              outputScores[0][i],  //最大概率或得分最高
              detection));
    }