簡體   English   中英

嘗試在 Android 工作室中實現 tensorflow lite model 時,我收到不兼容的數據類型錯誤

[英]When trying to implement tensorflow lite model in Android studio I am getting an incompatible data types error

我使用 Z063009BB15C8272BD0C701CFDF01126 在 python 中構建了一個簡單的 tensorflow 多類分類 model。 我現在正試圖在 Android 工作室中運行 model 但在 Android 工作室中出現以下錯誤。

java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite tensor with type INT64 and a Java object of type [[F (which is compatible with the TensorFlowLite type FLOAT32).

我將首先粘貼用於創建 model 的 google colab notebook 代碼。 本質上,有兩個輸入變量(2 個整數),然后 model 根據這兩個輸入變量對獲得 171 個不同類別的概率進行分類。 然后,我將包含 android 工作室中使用的 java 代碼,以嘗試運行 tflite model。

這是用於構建分類 model 的 python 代碼

import io
df = pd.read_csv(io.BytesIO(data_to_load['species_by_location_v4.csv']))

# CREATE X ARRAY
# This is the array containing the explanatory variables (in this case pentad and month)

loc_array = df.iloc[:, 1:3]
print(loc_array)

# create y array (classes to be predicted)

label_array = df.iloc[:, 4]

# get number of distinct classes and convert y array to consecutive integers from 0 to 170 (y_true)

raw_y_true = label_array
mapping_to_numbers = {}
y_true = np.zeros((len(raw_y_true)))
for i, raw_label in enumerate(raw_y_true):
    if raw_label not in mapping_to_numbers:
        mapping_to_numbers[raw_label] = len(mapping_to_numbers)
    y_true[i] = mapping_to_numbers[raw_label]
print(y_true)
# [0. 1. 2. 3. 1. 2.]
print(mapping_to_numbers)

# get number of distinct classes

num_classes = len(mapping_to_numbers)
print(num_classes)


# create simple model

model = tf.keras.Sequential([
  tf.keras.layers.Dense(num_classes, activation='softmax')
])

# compile model

model.compile(
    optimizer = 'adam',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy']
)

# train model

history = model.fit(loc_array, y_true, epochs=500, verbose=False)
print('finished')

# create labels file

labels = '\n'.join(mapping_to_numbers.keys())

with open('labels_locmth.txt', 'w') as f:
  f.write(labels)

!cat labels.txt

# convert to tflite

saved_model_dir = 'save/fine_tuning'
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open('model_locmth.tflite', 'wb') as f:
  f.write(tflite_model)


在 python 中,我可以使用以下代碼使用 model 運行預測。 它只需要 2 個 integer 值作為輸入變量,在本例中為 63669 和 2。

temp = {'pentad_unique_key': [63669], 'mth': [2] }
test_x = pd.DataFrame(temp, columns = ['pentad_unique_key', 'mth'])
result = model.predict(test_x, batch_size=None, verbose=0, steps=None)

現在在 Android 工作室中,我使用以下代碼嘗試運行 model。 我試圖傳遞與上面 python 代碼中相同的兩個整數。 在這里的最后一行,我得到了上面的錯誤。

float[][] inputVal = new float[1][2];
inputVal[0][0] = 63669;
inputVal[0][1] = 2;

float[][] outputs = new float[1][171];

tflite.run(inputVal, outputs);

這就是我構建 tflite object 的方式

try{
   tflite = new Interpreter(loadModelFile());
   labelList = loadLabelList();
} catch (Exception ex) {
   ex.printStackTrace();
}

這是 loadModelFile 方法

private MappedByteBuffer loadModelFile() throws IOException {

    // Open the model using an input stream and memory map it to load
    AssetFileDescriptor fileDescriptor = this.getAssets().openFd("model_locmth.tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);

}

我感覺問題與 inputVal 數組有關?

如前所述,我認為問題出在y_true的輸入 dtype 上。 當您創建y_true

y_true = np.zeros((len(raw_y_true)))

它會將 dtype 顯示為 np.int64,因為那是默認的 dtype。
tflite 模型可以與錯誤中提到的 float32(不是 int64 或 float64)一起使用,因此您需要將 y_true 定義為

y_true = np.zeros((len(raw_y_true)),dtype=np.float32)

所以,我只更改了以下代碼部分中的上述行

label_array = df.iloc[:, 4]

raw_y_true = label_array
mapping_to_numbers = {}
y_true = np.zeros((len(raw_y_true)),dtype=np.float32)
for i, raw_label in enumerate(raw_y_true):
    if raw_label not in mapping_to_numbers:
        mapping_to_numbers[raw_label] = len(mapping_to_numbers)
    y_true[i] = mapping_to_numbers[raw_label]
print(y_true)
# [0. 1. 2. 3. 1. 2.]
print(mapping_to_numbers)

print(np.argmax(result)) # 122

一個建議是盡可能在代碼中使用 tensorflow 操作。 例如,不要使用np.zeros ,而是使用tf.zeros等。

早些時候,我在使用 float64 的簡單 model 中遇到了類似的錯誤。 model 的完整代碼和您可以在此處看到的錯誤。 希望能幫助到你。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM