
How can I get the number of trainable parameters of a model in Keras?

I am setting trainable=False in all my layers, which I built with the Model API, but I want to verify that this is working. model.count_params() returns the total number of parameters. Is there any way to get the total number of trainable parameters other than reading the last few lines of model.summary()?

import numpy as np
from keras import backend as K

# set() counts a weight shared between layers only once. This relies on the
# weights being hashable (multi-backend Keras); TensorFlow 2 variables are not.
trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

The snippet above comes from the end of the layer_utils.print_summary() definition, which is what summary() calls.
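K.count_params(p) returns the number of scalars in a weight tensor, i.e. the product of its shape. For Dense layers this makes the totals easy to check by hand: a kernel of shape (inputs, units) plus a bias of shape (units,). A small sketch, assuming a hypothetical 784 → 128 → 10 network (dense_params is an illustrative helper, not a Keras function):

```python
def dense_params(n_inputs, units, use_bias=True):
    """Parameter count of a Dense layer: kernel (n_inputs x units) plus optional bias."""
    return n_inputs * units + (units if use_bias else 0)

# Hypothetical MLP: 784 inputs -> Dense(128) -> Dense(10)
layers = [(784, 128), (128, 10)]
per_layer = [dense_params(n_in, units) for n_in, units in layers]
print(per_layer)        # [100480, 1290]
print(sum(per_layer))   # 101770
```

For an equivalent unfrozen model, these numbers should line up with the per-layer "Param #" column and the trainable total printed by model.summary().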


Edit: more recent versions of Keras have a helper function, count_params(), for this purpose:

from keras.utils.layer_utils import count_params

trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
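The helper takes a list of weights and returns the total number of scalars in them. In multi-backend Keras it is likely close to the first snippet: wrap the list in set() so a weight that appears twice (for example in a shared layer) is counted once, then sum each weight's shape product. A pure-Python sketch of that logic, assuming a hypothetical Weight class as a stand-in for a backend variable:

```python
from math import prod

class Weight:
    """Hypothetical stand-in for a backend variable; only .shape is used."""
    def __init__(self, shape):
        self.shape = tuple(shape)

def count_params(weights):
    # set() keeps one copy of each weight object; the default hash is identity-based
    return sum(prod(w.shape) for w in set(weights))

embedding = Weight((1000, 16))   # hypothetical weight shared by two layers
kernel = Weight((16, 1))
weights = [embedding, kernel, embedding]
print(count_params(weights))     # 16016, not 32016
```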

For TensorFlow 2.0:

import numpy as np
import tensorflow.keras.backend as K

# K.count_params(w) returns the number of scalars in w (the product of its shape)
trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
non_trainable_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

For tensorflow.keras, this works for me. It is taken from the TensorFlow GitHub code for the function print_layer_summary_with_connections() in layer_utils.py:

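import numpy as np
# Note: tensorflow.python.* is a private module and may change between releases
from tensorflow.python.util import object_identity


def count_params(weights):
    # Deduplicate by object identity, then sum the element count of each weight
    return int(sum(np.prod(p.shape.as_list())
                   for p in object_identity.ObjectIdentitySet(weights)))


# Use the trainable weights recorded by compile(), if this Keras version keeps them
if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
else:
    trainable_count = count_params(model.trainable_weights)

print(trainable_count)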
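One reason this version uses ObjectIdentitySet rather than a plain set(): TensorFlow 2 variables are unhashable (their == compares values elementwise), so putting them in a set() raises a TypeError. A pure-Python sketch of identity-based deduplication, assuming a hypothetical FakeVariable class that imitates that unhashability:

```python
from math import prod

class FakeVariable:
    """Hypothetical stand-in for a TF2 variable: unhashable, like tf.Variable."""
    __hash__ = None

    def __init__(self, shape):
        self.shape = tuple(shape)

def count_params_by_identity(weights):
    # Key on id() instead of the object itself, as an identity-based set does
    unique = {id(w): w for w in weights}
    return sum(prod(w.shape) for w in unique.values())

kernel = FakeVariable((8, 2))
bias = FakeVariable((2,))
weights = [kernel, bias, kernel]   # kernel listed twice, e.g. a shared layer

try:
    set(weights)
except TypeError as exc:
    print("set() fails:", exc)     # set() fails: unhashable type: 'FakeVariable'

print(count_params_by_identity(weights))   # 18
```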

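Another way to count parameters is the following, although, as the question notes, it returns the total (trainable plus non-trainable), not only the trainable parameters: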

model.count_params()
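model.count_params() adds both groups together, so on its own it cannot confirm that trainable=False took effect; the useful check is that the trainable part drops to zero while the total stays the same. A minimal pure-Python sketch, assuming a hypothetical list of (shape, trainable) pairs in place of a real Keras layer; the shapes mirror a BatchNormalization layer over 4 features, whose moving mean and variance are non-trainable even before freezing:

```python
from math import prod

def split_counts(weights):
    """Return (trainable, non_trainable) totals for (shape, trainable) pairs."""
    trainable = sum(prod(shape) for shape, is_trainable in weights if is_trainable)
    non_trainable = sum(prod(shape) for shape, is_trainable in weights if not is_trainable)
    return trainable, non_trainable

# Hypothetical stand-in for a BatchNormalization layer over 4 features:
# gamma and beta are trained, moving_mean and moving_variance are not.
bn = [((4,), True), ((4,), True), ((4,), False), ((4,), False)]
print(split_counts(bn))        # (8, 8)

# After trainable=False, every weight counts as non-trainable; the total is unchanged.
frozen = [(shape, False) for shape, _ in bn]
print(split_counts(frozen))    # (0, 16)
```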
