Exporting Keras model into TFLite
I am trying to combine these two examples and create the tflite file for my Android app:
https://medium.com/nybles/create-your-first-image-recognition-classifier-using-cnn-keras-and-tensorflow-backend-6eaab98d14dd
https://medium.com/@xianbao.qian/convert-keras-model-to-tflite-e2bdf28ee2d2
This is my code:
# Part 1 - Building the CNN
# Importing the Keras libraries and packages
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
import tensorflow as tf
from keras.models import load_model
# Initialising the CNN
classifier = Sequential()
# Step 1 - Convolution
classifier.add(Convolution2D(32, 3, 3, input_shape = (64, 64, 3), activation = 'relu'))
# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Adding a second convolutional layer
classifier.add(Convolution2D(32, 3, 3, activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 3 - Flattening
classifier.add(Flatten())
# Step 4 - Full connection
classifier.add(Dense(output_dim = 128, activation = 'relu'))
classifier.add(Dense(output_dim = 1, activation = 'sigmoid'))
# Compiling the CNN
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
# Part 2 - Fitting the CNN to the images
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'binary')
test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size = (64, 64),
                                            batch_size = 32,
                                            class_mode = 'binary')
classifier.fit_generator(training_set,
                         samples_per_epoch = 80,
                         nb_epoch = 1,
                         validation_data = test_set,
                         nb_val_samples = 20)
output_names = [node.op.name for node in classifier.outputs]
sess = tf.keras.backend.get_session()
frozen_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
tflite_model = tf.contrib.lite.toco_convert(frozen_def, [inputs], output_names)
with tf.gfile.GFile(tflite_graph, 'wb') as f:
    f.write(tflite_model)
And at this line:
frozen_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
I got an exception:
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value conv2d_1/bias
[[Node: _retval_conv2d_1/bias_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/device:CPU:0"](conv2d_1/bias)]]
I am a beginner in machine learning and have absolutely no idea what this error is about :-(
Can somebody explain to me what is wrong? All I need is to process several folders with many pictures and then be able to predict which folder a newly arriving picture belongs to.
Thank you.
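A side note on the stated goal: the model above ends in a single sigmoid unit trained with class_mode = 'binary', so it can only separate two folders. For several folders one would typically use class_mode = 'categorical', a final Dense layer with one softmax unit per folder, and loss = 'categorical_crossentropy'. flow_from_directory records the folder-to-index mapping in training_set.class_indices. The sketch below, in plain Python, turns one image's model output back into a folder name; the helper name predict_folder and the folder names are hypothetical, for illustration only.

```python
def predict_folder(probs, class_indices, threshold=0.5):
    """Map the model output for ONE image to a training-folder name.

    probs: list of floats for a single image.
      - length 1: sigmoid output (class_mode='binary'), i.e. P(index 1)
      - length N > 1: softmax output (class_mode='categorical')
    class_indices: dict like training_set.class_indices (folder name -> index).
    """
    # Invert the mapping so an index can be turned back into a folder name.
    index_to_name = {idx: name for name, idx in class_indices.items()}
    if len(probs) == 1:
        # Binary case: the single sigmoid unit estimates P(class index 1).
        idx = 1 if probs[0] >= threshold else 0
    else:
        # Multi-class case: pick the index with the highest probability.
        idx = max(range(len(probs)), key=lambda i: probs[i])
    return index_to_name[idx]


# Hypothetical folder names, for illustration only.
print(predict_folder([0.83], {"cats": 0, "dogs": 1}))              # dogs
print(predict_folder([0.1, 0.7, 0.2], {"a": 0, "b": 1, "c": 2}))  # b
```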
It is possible to convert a Keras model directly to .tflite using the tf.lite.TFLiteConverter.from_session function. Place the following code after fit_generator to export it (tested with TensorFlow 1.13.1):
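# Note: standalone `keras` (used in the question) and tf.keras keep separate sessions,
# which is why the variables looked uninitialized in tf.keras's session.
with tf.keras.backend.get_session() as sess:
    # Re-runs every variable's initializer, so trained values in this session are overwritten.
    # To keep the trained weights, keras.backend.get_session() without this line may be needed.
    sess.run(tf.global_variables_initializer())
    converter = tf.lite.TFLiteConverter.from_session(sess, classifier.inputs, classifier.outputs)
    tflite_model = converter.convert()
    with open("model.tflite", "wb") as f:
        f.write(tflite_model)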
A bit late to the party, but here's how you do it in TensorFlow 2.x:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
Source: https://www.tensorflow.org/lite/convert/python_api
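Put together with the question's network, the from_keras_model route might look like the sketch below. It assumes TensorFlow 2.x where tf.keras is Keras 2 (with Keras 3 the converter path may behave differently), rebuilds the question's CNN with tf.keras layer names so only one Keras implementation is involved, and assumes a hypothetical NUM_CLASSES = 3 folders with a softmax output. Training is left as a comment; the random input is a stand-in for a rescaled 64x64 image. The last part runs the converted file once through tf.lite.Interpreter as a sanity check.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 3  # assumption: one output unit per training folder

# The question's architecture, expressed with tf.keras (TF 2.x layer names).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
# Train here, e.g. model.fit(training_set, validation_data=test_set, epochs=...)

# Convert the in-memory Keras model and write the .tflite file.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

# Sanity check: run the converted model once with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
image = np.random.rand(1, 64, 64, 3).astype(np.float32)  # stand-in for a rescaled image
interpreter.set_tensor(input_details["index"], image)
interpreter.invoke()
probs = interpreter.get_tensor(output_details["index"])
print(probs.shape)  # (1, 3)
```

On Android the same model.tflite would be loaded with the TFLite Interpreter, fed images rescaled by 1/255 as during training.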