繁体   English   中英

如何在没有模型的情况下读取keras模型权重

[英]How to read keras model weights without a model

keras模型可以保存在两个文件中。 一个文件具有模型体系结构。 另一个是模型权重,权重由方法model.save_weights()保存。

然后可以使用model.load_weights(file_path)加载权重。 它假定模型存在。

我只需要加载没有模型的权重。 我试着用pickle.load()

with open(file_path, 'rb') as fp:
    w = pickle.load(fp)

但它给出了错误:

_pickle.UnpicklingError: invalid load key, 'H'.

我认为权重文件以不兼容的方式保存。 是否可以从model.save_weights()创建的文件中仅加载权重?

数据格式为h5,因此您可以直接使用h5py库来检查和加载权重。 快速入门指南

import h5py
f = h5py.File('weights.h5', 'r')
print(list(f.keys())
# will get a list of layer names which you can use as index
d = f['dense']['dense_1']['kernel:0']
# <HDF5 dataset "kernel:0": shape (128, 1), type "<f4">
d.shape == (128, 1)
d[0] == array([-0.14390108], dtype=float32)
# etc.

该文件包含属性,包括图层的权重,您可以详细探索存储的内容和方式。 如果你想要一个可视版本,那么也有h5pyViewer

参考: https//github.com/keras-team/keras/issues/91下面的问题代码片段

from __future__ import print_function

import h5py

def print_structure(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    f = h5py.File(weight_file_path)
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")

        print("  f.attrs.items(): ")
        for key, value in f.attrs.items():           
            print("  {}: {}".format(key, value))

        if len(f.items())==0:
            print("  Terminate # len(f.items())==0: ")
            return 

        print("  layer, g in f.items():")
        for layer, g in f.items():            
            print("  {}".format(layer))
            print("    g.attrs.items(): Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                subkeys = param.keys()
                print("    Dataset: param.keys():")
                for k_name in param.keys():
                    print("      {}/{}: {}".format(p_name, k_name, param.get(k_name)[:]))
    finally:
        f.close()
print_structure('weights.h5.keras')

您需要创建一个Keras Model ,然后您可以加载您的architecture ,然后加载model weights

请参阅下面的代码,

model = keras.models.Sequential()          # create a Keras Model
model.load_weights('my_model_weights.h5')  # load model weights

Keras文档中的更多信息

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM