Is there a way to know how many parameters an object detection model has, in the TensorFlow Object Detection API?
I train different models with the TensorFlow Object Detection (TFOD) API and I would like to know how many parameters are trained for a given model. I run Faster RCNN, SSD, and RFCN, also with different image resolutions, and I would like a way to know how many parameters are trained. Is there a way to do that?
I have tried the answers found here: How to count total number of trainable parameters in a tensorflow model? with no luck.
Here is the code I added at line 103 of model_main.py:
print("Training {} parameters".format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
I think the problem is that I do not have access to the tf.Session() that TFOD is running, hence my code always returns 0.0 parameters (although training starts just fine and trains, hopefully, millions of parameters), and I don't know how to solve that issue.
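The counting expression itself can be checked in isolation: it sums, over all trainable variables, the product of each variable's dimensions. A minimal sketch, using hypothetical shapes in place of real TensorFlow variables (the shapes below are made up for illustration):

```python
import numpy as np

# Hypothetical variable shapes: a 3x3 conv over 3 channels with 64 filters,
# its bias, a 1024x10 dense layer, and its bias.
variable_shapes = [(3, 3, 3, 64), (64,), (1024, 10), (10,)]

def count_params(shapes):
    """Total parameters = sum over variables of the product of their dims."""
    return int(sum(np.prod(s) for s in shapes))

print(count_params(variable_shapes))
```

If this formula returns 0, the variable list it was given was empty, which is what happens when tf.trainable_variables() is called before the Estimator has built the graph.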
The TFOD API uses tf.estimator.Estimator to train and evaluate. The Estimator object provides a function to get all the variable names, Estimator.get_variable_names() ( reference ).
You can add the line print(estimator.get_variable_names()) after estimator.train_and_evaluate() ( here ).
You will see all the variable names printed after training is completed. To see the results faster, you can train for just 1 step.
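Beyond listing the names, the checkpoint values themselves can be used to get the parameter count: Estimator also exposes get_variable_value(name), which returns the variable as a NumPy array. A minimal sketch (the helper function is my own, not part of the TFOD API; it assumes the estimator has completed at least one training step so a checkpoint exists):

```python
import numpy as np

def count_estimator_params(get_value, names):
    """Sum parameter counts given accessors matching tf.estimator.Estimator's
    get_variable_names() / get_variable_value() interface."""
    return int(sum(np.prod(get_value(name).shape) for name in names))

# With a real estimator, after estimator.train_and_evaluate():
# total = count_estimator_params(estimator.get_variable_value,
#                                estimator.get_variable_names())
# print("Trained {} parameters".format(total))
```

Note that this counts every checkpointed variable, including optimizer slots such as Adam moments, so you may want to filter the name list before summing.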
When using export_inference_graph.py, the script also analyzes your model and counts parameters and FLOPS (if possible). It looks like this:
_TFProfRoot (--/# total params)
FeatureExtractor (--/# params)
...
WeightSharedConvolutionalBoxPredictor (--/# params)
...