繁体   English   中英

TensorFlow将模型导出转换保存为tflite

[英]TensorFlow saved model export conversion to tflite

TLDR :运行时出现ValueError:

tf.contrib.lite.TocoConverter.from_saved_model()

目的 :我正在尝试将TensorFlow保存的模型转换为tflite,以便通过Firebase部署在移动设备上。 我可以训练模型并输出保存的模型,但是在使用python ToCo接口将其转换为.tflite时遇到了麻烦。 任何帮助将不胜感激。 另外,是否有人可以评论tflite转换是否将捕获我所依赖的hub.text_embedding_column()输入过程。 移动部署将使用原始输入文本执行此操作,还是需要单独部署其中的那一部分?

问题 :这是我正在运行的代码:

输入:

train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["target_var"], num_epochs=None, shuffle=True
)

predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["target_var"], shuffle=False
)

predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
    test_df, test_df["target_var"], shuffle=False)

embedded_text_feature_column = hub.text_embedding_column(
    key="text", 
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1"
)

培训和评估:

estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),
    model_dir="my-model"
)

estimator.train(input_fn=train_input_fn, steps=1000)

train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)

保存模式:

feature_spec = tf.feature_column.make_parse_example_spec([embedded_text_feature_column])

serve_input_fun = tf.estimator.export.build_parsing_serving_input_receiver_fn(
    feature_spec,
    default_batch_size=None
)

estimator.export_savedmodel(
    export_dir_base = "my-model",
    serving_input_receiver_fn = serve_input_fun,
    as_text=False,
    checkpoint_path="my-model/model.ckpt-1000",
)

转换模型:

converter = tf.contrib.lite.TocoConverter.from_saved_model("my-model/1529320265/") 
tflite_model = converter.convert()

错误

运行最后一行时,出现以下错误:

ValueError:张量input_example_tensor:0未知类型tf.string

完整的跟踪是:

ValueError跟踪(最近一次通话)
在()中
1个转换器= tf.contrib.lite.TocoConverter.from_saved_model(“ my-model / 1529320265 /”)
----> 2 tflite_model = converter.convert()

/media/rmn/data/projects/anaconda3/envs/monily_tf19/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py在convert(self)中
307 reorder_across_fake_quant = self.reorder_across_fake_quant,
第308章席卷天元
-> 309 allow_custom_ops = self.allow_custom_ops)
310返回结果
311
/media/rmn/data/projects/anaconda3/envs/monily_tf19/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py in toco_convert(输入数据,输入张量,输出张量,推断类型,推断输入类型,输入格式,output_format,quantized_input_stats,default_ranges_stats,drop_control_dependency,reorder_across_fake_quant,allow_custom_ops,change_concat_input_ranges)
204其他:
205提高ValueError(“张量%s未知类型%r”%(input_tensor.name,-> 206 input_tensor.dtype))
207
第208章

ValueError:张量input_example_tensor:0未知类型tf.string

细节

train_dftest_df是由单个输入文本列和二进制目标变量组成的pandas数据帧。 我正在使用Python 3.6.5和TensorFlow r1.9。

此问题已在TensorFlow的master分支上修复(在commit d3931c8中 )。 请参考TensorFlow网站上的以下文档以从GitHub构建pip安装: https : //www.tensorflow.org/install/install_sources

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM