
How to load a TensorFlow model from internal storage?

I wanted to know if it's possible to store and read the trained .tflite model from the Android device's internal storage instead of the assets folder.

Below is the original code (which works) for loading the model from the assets folder.

private MappedByteBuffer loadLocalModelFile() throws IOException {
  AssetFileDescriptor fileDescriptor = getAssets().openFd(MODEL_PATH);
  FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  long startOffset = fileDescriptor.getStartOffset();
  long declaredLength = fileDescriptor.getDeclaredLength();

  FileChannel fileChannel = inputStream.getChannel();
  return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

Is there a way to load the model from internal storage instead and still get the startOffset and declaredLength for fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)? If not, is there a way to calculate the startOffset and declared length of a new model when reading the raw binary from internal storage?

I tried using the openNonAssetFd() function from AssetManager to obtain an AssetFileDescriptor for my file located in internal storage.

private MappedByteBuffer loadOnlineModelFile() throws IOException {
    FileInputStream inputStream = openFileInput(MODEL);

    AssetManager manager = getAssets();
    AssetFileDescriptor afd = manager.openNonAssetFd(getFilesDir() + "/graph.lite");

    long startOffset = afd.getStartOffset();
    long declaredLength = afd.getDeclaredLength();

    FileChannel fileChannel = inputStream.getChannel();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

However, this results in "java.lang.IllegalArgumentException: Model ByteBuffer should be either a MappedByteBuffer of the model file or a direct ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content" and "java.io.FileNotFoundException".

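Well, I've been searching all over the place and finally figured it out. It's dead simple.
For some reason I thought that AssetFileDescriptor's getStartOffset() was related to the tflite model itself, but it isn't. Assets are packed inside the app's APK, and the descriptor returned by openFd() refers to that whole APK file: getStartOffset() is the byte position where the asset's data begins inside the APK, and getDeclaredLength() is the asset's size. (openNonAssetFd() likewise only opens files packaged in the APK, which is why it throws FileNotFoundException for a path in internal storage.) A model saved as a standalone file is the only thing in that file, so its start offset is 0 and its length is simply the file's length. So the code should be: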

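// Needs java.io.File, java.io.FileInputStream, java.nio.channels.FileChannel
File file = new File("path_to_model");
// Closing the stream (and its channel) does not invalidate the mapping
try (FileInputStream is = new FileInputStream(file)) {
  // A standalone file starts at byte 0 and is file.length() bytes long
  return is.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length());
}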
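To make the offset/length idea concrete, here is a small plain-Java sketch with no Android APIs, so it can run on a desktop JVM. The class and method names are hypothetical. map() maps a region at an arbitrary offset, which is what the asset code does with the values from AssetFileDescriptor; mapWholeFile() is the standalone-file case from this answer, offset 0 and length file.length().

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

final class MapRegion {
    private MapRegion() {}

    // Maps `length` bytes of `file` read-only, starting at byte `offset`.
    // This mirrors the asset case, where offset and length come from an
    // AssetFileDescriptor and describe where the asset sits inside the APK.
    static MappedByteBuffer map(File file, long offset, long length) throws IOException {
        try (FileInputStream in = new FileInputStream(file)) {
            // Closing the stream (and its channel) does not invalidate the mapping
            return in.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length);
        }
    }

    // A standalone model file is the only thing in its file:
    // it starts at byte 0 and its length is the file's length.
    static MappedByteBuffer mapWholeFile(File file) throws IOException {
        return map(file, 0, file.length());
    }
}
```

For example, for a file holding a 3-byte header followed by the model, map(file, 3, modelLength) returns only the model bytes, much like openFd() lets you map a single asset out of the APK. For a model file on its own, mapWholeFile(file) is all you need.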

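You can also access the file directly by its path. Here is demo code that reads a tflite model named model.tflite from a sample folder. Note that Environment.getExternalStorageDirectory() points to the device's shared external storage, not the app's private internal storage, and it is deprecated since API level 29; to read from internal storage, build the File from context.getFilesDir() instead.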

import android.content.Context;
import android.os.Environment;
import androidx.annotation.NonNull;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
// SupportPreconditions comes from the TensorFlow Lite Support library (its package depends on the version)

@NonNull
public MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) throws IOException {
  SupportPreconditions.checkNotNull(context, "Context should not be null.");
  SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  // <shared storage root>/sample/<filePath>; use context.getFilesDir() for app-private internal storage
  File file = new File(Environment.getExternalStorageDirectory() + "/sample/" + filePath);

  // try-with-resources closes the stream; the mapping stays valid afterwards
  try (FileInputStream inputStream = new FileInputStream(file)) {
    FileChannel fileChannel = inputStream.getChannel();
    // A standalone file starts at offset 0 and spans the whole file
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length());
  }
}

The filePath argument is just the model's file name, here model.tflite. Call the method like this:

loadMappedFile(Classifier.this, "model.tflite");
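The question asked about storing the model in internal storage as well as reading it. Below is a minimal sketch of both steps, assuming the model arrives as an InputStream (for example from a download); InternalModelStore, save and load are hypothetical names. It uses only java.io and java.nio so it can be tested off-device; on Android you would pass context.getFilesDir() as dir.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

final class InternalModelStore {
    private InternalModelStore() {}

    // Copies all bytes from `source` into dir/fileName, overwriting any old copy.
    // On Android, `dir` would typically be context.getFilesDir() (app-private storage).
    // The caller is responsible for closing `source`.
    static File save(InputStream source, File dir, String fileName) throws IOException {
        File target = new File(dir, fileName);
        try (OutputStream out = new FileOutputStream(target)) {
            byte[] buffer = new byte[8192];
            int n;
            while ((n = source.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
        }
        return target;
    }

    // Memory-maps the stored model read-only. A standalone file starts at
    // offset 0, so no AssetFileDescriptor is needed.
    static MappedByteBuffer load(File dir, String fileName) throws IOException {
        File file = new File(dir, fileName);
        if (!file.isFile()) {
            throw new IOException("Model not found: " + file.getAbsolutePath());
        }
        try (FileInputStream in = new FileInputStream(file)) {
            return in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length());
        }
    }
}
```

On a device this might be called as InternalModelStore.save(downloadStream, context.getFilesDir(), "model.tflite") and later InternalModelStore.load(context.getFilesDir(), "model.tflite"), where downloadStream is a placeholder. If an interpreter may still be using a mapping of the old file, writing to a temporary name and renaming it over the old one is a safer variant, because overwriting a file that is still mapped can change or invalidate the mapped contents.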
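Combining several answers on Stack Overflow, I found a solution that works in my Android application. Instead of mapping a file, it copies the model bytes into a direct ByteBuffer in native byte order, which is the other form the error message in the question accepts. This example reads from the assets folder, but the same steps work with any InputStream, such as one from context.openFileInput() for internal storage.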

    // Needs java.io.ByteArrayOutputStream, java.io.IOException, java.io.InputStream,
    // java.nio.ByteBuffer, java.nio.ByteOrder and the Task Library's ImageClassifier.
    // The enclosing method must declare or handle IOException.
    byte[] modelBytes;
    try (InputStream inputStream = context.getAssets().open("model.tflite");
         ByteArrayOutputStream output = new ByteArrayOutputStream()) {
        byte[] buffer = new byte[8192];
        int bytesRead;
        while ((bytesRead = inputStream.read(buffer)) != -1) {
            output.write(buffer, 0, bytesRead);
        }
        modelBytes = output.toByteArray();
    }
    // Copy into a direct buffer in native byte order, as the model loader requires
    ByteBuffer bb = ByteBuffer.allocateDirect(modelBytes.length);
    bb.order(ByteOrder.nativeOrder());
    bb.put(modelBytes);
    bb.rewind(); // move the position back to 0 after put()
    imageClassifier = ImageClassifier.createFromBuffer(bb);
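Because this approach only needs an InputStream, the copying steps can be pulled into a small helper. A sketch in plain Java follows; ModelBuffers.toDirectBuffer is a hypothetical name. On Android the stream could come from context.getAssets().open("model.tflite") or, for internal storage, context.openFileInput("model.tflite").

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class ModelBuffers {
    private ModelBuffers() {}

    // Reads `in` to the end and returns a direct, native-order buffer holding
    // exactly those bytes, positioned at 0. The caller closes `in`.
    static ByteBuffer toDirectBuffer(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        int n;
        while ((n = in.read(chunk)) != -1) {
            out.write(chunk, 0, n);
        }
        byte[] bytes = out.toByteArray();
        ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
        buffer.put(bytes);
        buffer.rewind(); // reset the position to 0 so reads start at the first byte
        return buffer;
    }
}
```

The trade-off compared with memory-mapping is that the whole model is copied into memory (briefly twice, once in the byte array and once in the direct buffer), whereas a MappedByteBuffer lets the OS page the file in on demand.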

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM