
Use Tensorflow model in android

I have a TensorFlow model and I have converted it to ".tflite", but I don't know how to implement it on Android. I followed the TensorFlow guidelines to implement it on Android, but since the TensorFlow website gives no XML code, I am struggling to connect the model with the front end (XML). I need a clear explanation of how to use my model in Android Studio using Java.

I followed the official instructions given on the TensorFlow website to implement the model on Android.

Here is sample code showing how to implement object detection based on a TFLite model from TensorFlow. I suppose this kind of answer is not the best one, but I happened to have a simple example of your exact problem.

Note: it does detect objects and outputs their labels to standard output using Log.d. No boxes or labels will be drawn around detected objects (a sketch of how you could draw them follows the Java code below).

Download the starter model and labels from here. Put them into the assets folder of your project.

Java

import android.content.pm.PackageManager;
import android.media.Image;
import android.os.Bundle;
import android.util.Log;
import android.widget.Toast;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.Camera;
import androidx.camera.core.CameraSelector;
import androidx.camera.core.ExperimentalGetImage;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.lifecycle.ProcessCameraProvider;
import androidx.camera.view.PreviewView;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.mlkit.common.model.LocalModel;
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.objects.DetectedObject;
import com.google.mlkit.vision.objects.ObjectDetection;
import com.google.mlkit.vision.objects.ObjectDetector;
import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;

public class ActivityExample extends AppCompatActivity {
    private ListenableFuture<ProcessCameraProvider> cameraProviderFuture;
    private ObjectDetector objectDetector;
    private PreviewView prevView;
    private List<String> labels;

    private int REQUEST_CODE_PERMISSIONS = 101;
    private String[] REQUIRED_PERMISSIONS =
            new String[]{"android.permission.CAMERA"};

    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        setContentView(R.layout.activity_fullscreen);
        prevView = findViewById(R.id.viewFinder);

        prepareObjectDetector();
        prepareLabels();

        if (allPermissionsGranted()) {
            startCamera();
        } else {
            ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
        }
    }

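    // Loads the label file from assets so classification indices can be mapped to human-readable names.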
    private void prepareLabels() {
        try {
            InputStreamReader reader = new InputStreamReader(getAssets().open("labels_mobilenet_quant_v1_224.txt"));
            labels = readLines(reader);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private List<String> readLines(InputStreamReader reader) {
        BufferedReader bufferedReader = new BufferedReader(reader, 8 * 1024);
        Iterator<String> iterator = new LinesSequence(bufferedReader);

        ArrayList<String> list = new ArrayList<>();

        while (iterator.hasNext()) {
            list.add(iterator.next());
        }

        return list;
    }

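    // Configures ML Kit's custom object detector to run the local TFLite model bundled in assets.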
    private void prepareObjectDetector() {
        CustomObjectDetectorOptions options = new CustomObjectDetectorOptions.Builder(loadModel("mobilenet_v1_1.0_224_quant.tflite"))
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableMultipleObjects()
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();
        objectDetector = ObjectDetection.getClient(options);
    }

    private LocalModel loadModel(String assetFileName) {
        return new LocalModel.Builder()
                .setAssetFilePath(assetFileName)
                .build();
    }

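    // Asynchronously obtains the process-wide camera provider and binds the use cases once it is ready.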
    private void startCamera() {
        cameraProviderFuture = ProcessCameraProvider.getInstance(this);
        cameraProviderFuture.addListener(() -> {
            try {
                ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
                bindPreview(cameraProvider);
            } catch (ExecutionException e) {
                // No errors need to be handled for this Future.
                // This should never be reached.
            } catch (InterruptedException e) {
            }
        }, ContextCompat.getMainExecutor(this));
    }

    private void bindPreview(ProcessCameraProvider cameraProvider) {
        Preview preview = new Preview.Builder().build();
        CameraSelector cameraSelector = new CameraSelector.Builder()
                .requireLensFacing(CameraSelector.LENS_FACING_BACK)
                .build();
        ImageAnalysis imageAnalysis = new ImageAnalysis.Builder()
                .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                .build();
        YourAnalyzer yourAnalyzer = new YourAnalyzer();
        yourAnalyzer.setObjectDetector(objectDetector, labels);
        imageAnalysis.setAnalyzer(
                ContextCompat.getMainExecutor(this),
                yourAnalyzer);

        Camera camera =
                cameraProvider.bindToLifecycle(
                        this,
                        cameraSelector,
                        preview,
                        imageAnalysis
                );

        preview.setSurfaceProvider(prevView.createSurfaceProvider(camera.getCameraInfo()));
    }

    private Boolean allPermissionsGranted() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(
                    this,
                    permission
            ) != PackageManager.PERMISSION_GRANTED
            ) {
                return false;
            }
        }
        return true;
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera();
            } else {
                Toast.makeText(this, "Permissions not granted by the user.", Toast.LENGTH_SHORT)
                        .show();
                finish();
            }
        }
    }

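    // CameraX analyzer that forwards each camera frame to the ML Kit detector and logs the resulting labels.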
    private static class YourAnalyzer implements ImageAnalysis.Analyzer {
        private ObjectDetector objectDetector;
        private List<String> labels;

        public void setObjectDetector(ObjectDetector objectDetector, List<String> labels) {
            this.objectDetector = objectDetector;
            this.labels = labels;
        }

        @Override
        @ExperimentalGetImage
        public void analyze(@NonNull ImageProxy imageProxy) {
            Image mediaImage = imageProxy.getImage();
            if (mediaImage != null) {
                InputImage image = InputImage.fromMediaImage(
                        mediaImage,
                        imageProxy.getImageInfo().getRotationDegrees()
                );
                objectDetector
                        .process(image)
                        .addOnFailureListener(e -> imageProxy.close())
                        .addOnSuccessListener(detectedObjects -> {
                            // list of detectedObjects has all the information you need
                            StringBuilder builder = new StringBuilder();
                            for (DetectedObject detectedObject : detectedObjects) {
                                for (DetectedObject.Label label : detectedObject.getLabels()) {
                                    builder.append(labels.get(label.getIndex()));
                                    builder.append("\n");
                                }
                            }
                            Log.d("OBJECTS DETECTED", builder.toString().trim());
                            imageProxy.close();
                        });
            }
        }
    }


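    // Simple line-by-line iterator over a BufferedReader (a Java equivalent of Kotlin's lineSequence).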
    static class LinesSequence implements Iterator<String> {
        private BufferedReader reader;
        private String nextValue;
        private Boolean done = false;

        public LinesSequence(BufferedReader reader) {
            this.reader = reader;
        }

        @Override
        public boolean hasNext() {
            if (nextValue == null && !done) {
                try {
                    nextValue = reader.readLine();
                } catch (IOException e) {
                    e.printStackTrace();
                    nextValue = null;
                }
                if (nextValue == null) done = true;
            }
            return nextValue != null;
        }

        @Override
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            String answer = nextValue;
            nextValue = null;
            return answer;
        }
    }
}
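
As noted above, no boxes or labels are drawn. Each DetectedObject also exposes getBoundingBox(), which returns an android.graphics.Rect in the coordinate space of the analyzed image, so you can draw the detections yourself if you want. Below is a minimal, hedged sketch of a custom overlay view you could stack on top of the PreviewView; DetectionOverlayView is not part of the original answer, and the mapping from image coordinates to view coordinates is left out.

import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Rect;
import android.util.AttributeSet;
import android.view.View;

import java.util.ArrayList;
import java.util.List;

// Hypothetical overlay view (not part of the original answer). It simply draws
// whatever rectangles it is given; scaling from image to view coordinates is
// intentionally omitted and must be added for the boxes to line up visually.
public class DetectionOverlayView extends View {
    private final Paint boxPaint = new Paint();
    private List<Rect> boxes = new ArrayList<>();

    public DetectionOverlayView(Context context, AttributeSet attrs) {
        super(context, attrs);
        boxPaint.setColor(Color.RED);
        boxPaint.setStyle(Paint.Style.STROKE);
        boxPaint.setStrokeWidth(4f);
    }

    // Call this from the detector's success listener with the latest boxes.
    public void setBoxes(List<Rect> newBoxes) {
        boxes = newBoxes;
        invalidate(); // schedule a redraw
    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        for (Rect box : boxes) {
            canvas.drawRect(box, boxPaint);
        }
    }
}

In addOnSuccessListener you would collect detectedObject.getBoundingBox() for each detection into a list and pass it to setBoxes(). The analyzer is already bound to the main executor in this example, so updating the view from the callback is safe, but remember that the rects are in the analyzed image's coordinate space, not the PreviewView's.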

XML layout

<?xml version="1.0" encoding="utf-8"?>
<androidx.camera.view.PreviewView
    xmlns:android="http://schemas.android.com/apk/res/android"
    android:id="@+id/viewFinder"
    android:layout_width="match_parent"
    android:layout_height="match_parent" />

Gradle file configuration

android {
    ...
    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}


dependencies {
    ...
    
    implementation 'com.google.mlkit:object-detection-custom:16.0.0'
    def camerax_version = "1.0.0-beta03"
    // CameraX core library using camera2 implementation
    implementation "androidx.camera:camera-camera2:$camerax_version"
    // CameraX Lifecycle Library
    implementation "androidx.camera:camera-lifecycle:$camerax_version"
    // CameraX View class
    implementation "androidx.camera:camera-view:1.0.0-alpha10"
}
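
The runtime permission request in the activity assumes the camera permission is also declared in the manifest. A minimal sketch of the entries you would add to AndroidManifest.xml (they are not shown in the original answer):

<uses-feature android:name="android.hardware.camera.any" />
<uses-permission android:name="android.permission.CAMERA" />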
