简体   繁体   English

如何正确保存 keras 模型以便能够使用 hub.Module() 加载?

[英]How to correctly save keras model to be able to load with hub.Module()?

I am attempting to retrain inception v3 on a new image set.我正在尝试在新图像集上重新训练 inception v3。

When I try to save the model I receive an error.当我尝试保存模型时收到错误消息。

I have tried:我试过了:

    tf.keras.models.save_model(model, filename)

and

    model.save(filename)

and

    tf.contrib.saved_model.save_keras_model(model, filename)       

All give me a similar error, Module has no ' name '都给我一个类似的错误,Module has no ' name '

I have attached my code relevant to the problem.我附上了与该问题相关的代码。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt

FLAGS = None
def create_model(m, img_data):
    # load feature extractor (inception_v3)
    features_extractor_layer = tf.keras.layers.Lambda(m, input_shape=img_data.image_shape)

    # make pre-trained layers un-trainable
    features_extractor_layer.trainable = False

    print(features_extractor_layer.name)

    # add new activation layer to train to our classes
    model = tf.keras.Sequential([
        features_extractor_layer,
        tf.keras.layers.Dense(img_data.num_classes, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
def get_and_gen_images(module):
    """
    get images from image directory or url

    :param module: module (to get required image size info
    :return: batched image data
    """
    data_name = os.path.splitext(os.path.basename(FLAGS.image_dir_or_url))[0]
    print("data: ", data_name)

    # download images to cache if not already
    if FLAGS.image_dir_or_url.startswith('https://'):
        data_root = tf.keras.utils.get_file(data_name,
                                            FLAGS.image_dir_or_url,
                                            untar=True,
                                            cache_dir=os.getcwd())
    else:   # specify directory with images
        data_root = tf.keras.utils.get_file(data_name,
                                            FLAGS.image_dir_or_url)

    # get image size for specific module
    image_size = hub.get_expected_image_size(module)


    # TODO: this is where to add noise, rotations, shifts, etc.
    image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

    # create image stream
    train_image_data = image_generator.flow_from_directory(str(data_root),
                                                           target_size=image_size,
                                                           batch_size=FLAGS.batch_size,
                                                           subset='training')

    validation_image_data = image_generator.flow_from_directory(str(data_root),
                                                                target_size=image_size,
                                                                batch_size=FLAGS.batch_size,
                                                                subset='validation')

    return train_image_data, validation_image_data
# load module (will download from url or directory_
module = hub.Module(FLAGS.tfhub_module)

# generate image stream
train_image_data, validation_image_data = get_and_gen_images(module)

model = create_model(module, train_image_data)
model.summary()

file = FLAGS.saved_model_dir + "/modelname.h5"

model.save(file)

This should save a ".h5" model file, but I receive a naming error:这应该保存一个“.h5”模型文件,但我收到一个命名错误:

Traceback (most recent call last):
  File "/home/raphy/projects/vmi/tf_cpu/retrain.py", line 305, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/home/raphy/projects/vmi/tf_cpu/retrain.py", line 205, in main
    model.save(file)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 319, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 105, in save_model
    'config': model.get_config()
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 326, in get_config
    'config': layer.get_config()
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py", line 756, in get_config
    function = self.function.__name__
AttributeError: 'Module' object has no attribute '__name__'

I want to save the model in the format of the tf_hub models.我想以 tf_hub 模型的格式保存模型。

As specified in Hosting TF Hub TF Org Link ,托管 TF Hub TF Org Link中所述,

If you are interested in hosting your own repository of models that are loadable with the tensorflow_hub library, your HTTP distribution service should follow the following protocol.如果您有兴趣托管自己的可使用 tensorflow_hub 库加载的模型存储库,您的 HTTP 分发服务应遵循以下协议。

In other words, you cannot load any Model using TF Hub, but you can load only the Modules present in the TF Hub Modules Site .换句话说,您不能使用 TF Hub 加载任何模型,但您只能加载TF Hub 模块站点中存在的模块。

If you want to Load your Saved Model, you can do it using tf.saved_model.load .如果你想加载你保存的模型,你可以使用tf.saved_model.load来完成。

But if you want to do it using TF Hub, please refer this link .但如果您想使用 TF Hub 进行操作,请参考此链接

Also, mentioning the instructions below just in case if the link doesn't work:另外,如果链接不起作用,请提及以下说明:

Hosting your own models :托管您自己的模型

TensorFlow Hub provides an open repository of trained models at thub.dev . TensorFlow Hub 在thub.dev上提供了一个开放的训练模型存储库。 The tensorflow_hub library can load models from this repository and other HTTP based repositories of machine learning models. tensorflow_hub库可以从这个存储库和其他基于 HTTP 的机器学习模型存储库加载模型。 In particular the protocol allows to use the URL identifying the model both for the documentation of the model and the endpoint to fetch the model.特别是,该协议允许使用识别模型的 URL 来获取模型的文档和端点以获取模型。

If you are interested in hosting your own repository of models that are loadable with the tensorflow_hub library, your HTTP distribution service should follow the following protocol.如果您有兴趣托管自己的可使用tensorflow_hub库加载的模型存储库,您的 HTTP 分发服务应遵循以下协议。

Protocol:协议:

When a URL such as https://example.com/model is used to identify a model to load or instantiate, the model resolver will attempt to download a compressed tarball from the URL after appending a query parameter?当使用诸如https://example.com/model的 URL 来标识要加载或实例化的模型时,模型解析器将在附加查询参数后尝试从 URL 下载压缩的 tarball? tf-hub-format=compressed . tf-hub-format=compressed

The query param is to be interpreted as a comma separated list of the model formats that the client is interested in. For now only the "compressed" format is defined.查询参数将被解释为客户端感兴趣的模型格式的逗号分隔列表。目前仅定义了“压缩”格式。

The compressed format indicates that the client expects a tar.gz archive with the model contents.压缩格式表示客户端需要包含模型内容的tar.gz存档。 The root of the archive is the root of the model directory and should contain a SavedModel, as in this example:存档的根是模型目录的根,应该包含一个 SavedModel,如本例所示:

# Create a compressed model from a SavedModel directory.
$ tar -cz -f model.tar.gz --owner=0 --group=0 -C /tmp/export-model/ .

# Inspect files inside a compressed model
$ tar -tf model.tar.gz
./
./variables/
./variables/variables.data-00000-of-00001
./variables/variables.index
./assets/
./saved_model.pb

Tarballs for use with the deprecated hub.Module() API from TF1 will also contain a./tfhub_module.pb file.与 TF1 中已弃用的hub.Module() API 一起使用的 Tarball 还将包含一个 ./tfhub_module.pb 文件。 The hub.load() API for TF2 SavedModels ignores such a file. TF2 SavedModels 的 hub.load() API 会忽略此类文件。

The tensorflow_hub library expects that model URLs are versioned and that the model content of a given version is immutable, so that it can be cached indefinitely. tensorflow_hub 库期望模型 URL 是版本化的,并且给定版本的模型内容是不可变的,以便可以无限期地缓存。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM