
How to find the number of layers from the saved h5 file of a pre-trained model?

I have VGG-16 weights saved in h5 format. I want to see the number of layers in the network. How can I do that?

I tried using:

file = h5py.File('vgg16.h5', 'r')

After that, I checked file.attrs, but beyond that point I don't know which command to use to find the number of layers in the network.

If the .h5 file contains the full model (architecture and weights) and you have the Keras library installed, run the following code (for example in Spyder) and you will get all the layers in the file:

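from keras.models import load_model

# load_model needs a full model file (architecture + weights), e.g. one saved with model.save()
classifier = load_model('my_model.h5')

classifier.summary()            # prints every layer with its output shape and parameter count
print(len(classifier.layers))   # number of layers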
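If Keras is not available, the layer names can also be read with h5py alone. The sketch below assumes the file was written by Keras 2.x / tf.keras in the legacy HDF5 format, which stores a `layer_names` attribute on the `model_weights` group for a full model (`model.save`) or on the root group for a weights-only file (`save_weights`). The function name `keras_layer_names` and the demo files are hypothetical, and very large models may have this attribute split into several chunks, which the sketch does not handle.

```python
import h5py
import numpy as np


def keras_layer_names(path):
    """Return the layer names stored in a Keras legacy HDF5 file.

    Assumption: Keras 2.x / tf.keras wrote the file and kept a 'layer_names'
    attribute on 'model_weights' (full model) or on the root (weights only).
    """
    with h5py.File(path, 'r') as f:
        # Full-model files nest the weights under 'model_weights'.
        group = f['model_weights'] if 'model_weights' in f else f
        if 'layer_names' not in group.attrs:
            raise KeyError("no 'layer_names' attribute; not a Keras HDF5 file?")
        names = group.attrs['layer_names']
        # Names are usually stored as byte strings; decode them.
        return [n.decode('utf8') if isinstance(n, bytes) else str(n) for n in names]


# Hypothetical weights-only file mimicking the assumed layout.
with h5py.File('demo_weights.h5', 'w') as f:
    f.attrs['layer_names'] = np.array([b'conv_1', b'dropout', b'fc_8'])
    for name in ('conv_1', 'dropout', 'fc_8'):
        f.create_group(name)

# Hypothetical full-model file: names live on the 'model_weights' group.
with h5py.File('demo_model.h5', 'w') as f:
    mw = f.create_group('model_weights')
    mw.attrs['layer_names'] = np.array([b'input_1', b'fc_8'])

names = keras_layer_names('demo_weights.h5')
print(len(names), names)
print(keras_layer_names('demo_model.h5'))
```

Reading the attribute rather than `keys()` keeps the order in which Keras listed the layers; `keys()` returns group names sorted alphabetically.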

If you have only the weights and not the model structure, you can use the keys() method to get the names of all the layers in the weight file.

For example, suppose I have a weight file named saved-weight.h5.

To see which layers are present in this weight file, do the following:

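import h5py

# Open read-only; a Keras weight file has one top-level group per layer
file = h5py.File('saved-weight.h5', 'r')
layer_names = file.keys()
print(layer_names)
print(len(layer_names))   # number of layers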

# output of print(layer_names):
<KeysViewHDF5 ['add', 'bn_3', 'bn_5', 'bn_7',
    'concatenate', 'conv_1', 'conv_2', 'conv_3',
    'dropout', 'fc_8', 'fc_9',
    'gru_10', 'gru_10_back', 'gru_11', 'gru_11_back',
    'input_1', 'input_3', 'input_4', 'labels',
    'lambda', 'lambda_1', 'lambda_2', 'lambda_3',
    'maxpool_3', 'maxpool_5', 'model', 'permute', 'reshape']>

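These are the layers present in the saved-weight.h5 file. Note that the list also includes layers that have no weights, such as input, dropout and lambda layers.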
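If you only want to count layers that actually hold parameters, one option is to keep the top-level groups that contain at least one dataset. This is a heuristic sketch, not a Keras API; the function names and the demo file layout below are hypothetical. It relies on h5py's `Group.visititems`, which stops walking as soon as the callback returns something other than `None`.

```python
import h5py
import numpy as np


def has_dataset(group):
    """True if any dataset exists anywhere below this group."""
    # visititems stops early and returns the first non-None value from the callback.
    hit = group.visititems(
        lambda _name, item: True if isinstance(item, h5py.Dataset) else None
    )
    return hit is True


def layers_with_weights(path):
    """Heuristic: top-level group names that contain at least one stored array."""
    with h5py.File(path, 'r') as f:
        return [name for name, obj in f.items()
                if isinstance(obj, h5py.Group) and has_dataset(obj)]


# Hypothetical file: two parameter-free layers and one layer with weights.
with h5py.File('demo_layers.h5', 'w') as f:
    f.create_group('input_1')
    f.create_group('dropout')
    conv = f.create_group('conv_1/conv_1')  # h5py creates the intermediate group
    conv.create_dataset('kernel', data=np.zeros((3, 3, 1, 8)))
    conv.create_dataset('bias', data=np.zeros(8))

with h5py.File('demo_layers.h5', 'r') as f:
    print('all top-level entries:', len(f.keys()))
print('layers with weights:', layers_with_weights('demo_layers.h5'))
```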

Without your data or code, it is hard to give more details. To demonstrate the h5py methods for accessing h5 data, here is a simple example that creates an h5 file with 1 group containing 3 datasets. After the group and datasets are created, a loop prints each dataset's name, shape and dtype.

import h5py
import numpy as np

h5f = h5py.File('SO_54511719.h5', 'w')

ds_data = np.random.random(100).reshape(10, 10)
group1 = h5f.create_group('group1')
group1.create_dataset('ds_1', data=ds_data)
group1.create_dataset('ds_2', data=ds_data)
group1.create_dataset('ds_3', data=ds_data)

print('number of datasets in group:', len(group1))
for dsname, dsvalue in group1.items():
    print('for', dsname, ':')
    print('shape:', dsvalue.shape)
    print('dtype:', dsvalue.dtype)

h5f.close()

Output looks like this:

number of datasets in group: 3
for ds_1 :
shape: (10, 10)
dtype: float64
for ds_2 :
shape: (10, 10)
dtype: float64
for ds_3 :
shape: (10, 10)
dtype: float64
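The loop above only looks one level deep. Weight files often nest their arrays further (for Keras, typically a group per layer with the arrays inside it), so a recursive walk with h5py's `visititems` can help. A minimal sketch; the function name `list_datasets`, the file name and the group layout in the demo are made up for illustration.

```python
import h5py
import numpy as np


def list_datasets(path):
    """Return (path, shape, dtype) for every dataset in the file, at any depth."""
    found = []

    def collect(name, obj):
        # name is the path relative to the group visititems was called on.
        if isinstance(obj, h5py.Dataset):
            found.append((name, obj.shape, obj.dtype))

    with h5py.File(path, 'r') as f:
        f.visititems(collect)
    return found


# Hypothetical nested file: intermediate groups are created automatically.
with h5py.File('SO_nested_demo.h5', 'w') as h5f:
    h5f.create_dataset('block1/conv/kernel', data=np.zeros((3, 3, 3, 4)))
    h5f.create_dataset('block1/conv/bias', data=np.zeros(4))
    h5f.create_dataset('fc/kernel', data=np.zeros((10, 2), dtype=np.float32))

for name, shape, dtype in list_datasets('SO_nested_demo.h5'):
    print(name, shape, dtype)
```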
