
Is there a way to know how many parameters an object detection model has, in the TensorFlow Object Detection API?

I train different models with the TensorFlow Object Detection (TFOD) API and I would like to know how many parameters are trained for a given model.

I run Faster R-CNN, SSD, and R-FCN, at different image resolutions, and I would like a way to know how many parameters are trained. Is there a way to do that?

I have tried the answers found in How to count total number of trainable parameters in a tensorflow model?, with no luck.

Here is the code I added at line 103 of model_main.py:

print("Training {} parameters".format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

I think the problem is that I do not have access to the tf.Session() that TFOD is running, hence my code always returns 0.0 parameters (although training starts just fine and trains, hopefully, millions of parameters), and I don't know how to solve that issue.
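For what it's worth, the counting expression itself behaves as expected when it runs in a graph that actually contains variables. A minimal standalone sketch (TF1-style, with toy variables w and b that have nothing to do with any TFOD model) prints a nonzero count, which supports the idea that the problem is where the line runs, not the expression:

import numpy as np
import tensorflow as tf

# Toy graph: 64*128 weights + 128 biases = 8320 trainable parameters.
w = tf.get_variable("w", shape=[64, 128])
b = tf.get_variable("b", shape=[128])

# Same counting expression as above; only the variable shapes are needed,
# so no session has to be run.
total = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
print("Training {} parameters".format(int(total)))  # -> Training 8320 parameters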

The TFOD API uses tf.estimator.Estimator to train and evaluate. The Estimator object provides a function to get all the variable names, Estimator.get_variable_names() (reference).

You can add the line print(estimator.get_variable_names()) after tf.estimator.train_and_evaluate() (here).

You will see all variable names printed after the training is completed. To see the results faster, you can train for just 1 step.
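Since get_variable_names() only returns the names, here is a minimal sketch of turning them into a parameter count with Estimator.get_variable_value(). It assumes estimator is the tf.estimator.Estimator instance built in model_main.py and that at least one checkpoint has been written; note that the count includes non-trainable variables such as global_step and optimizer slots, so it is an upper bound on the trained parameters:

# Sum the number of elements in every variable stored in the latest checkpoint.
total = sum(estimator.get_variable_value(name).size
            for name in estimator.get_variable_names())
print("Checkpoint stores {} parameters (incl. optimizer slots)".format(total))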

When using export_inference_graph.py, the script also analyzes your model and counts parameters and FLOPS (if possible). It looks like this:

_TFProfRoot (--/# total params)
  FeatureExtractor (--/# params)
  ...
  WeightSharedConvolutionalBoxPredictor (--/# params)
  ...
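Alternatively, if you only have a trained checkpoint and do not want to run the exporter, a similar count can be read directly from the checkpoint file with tf.train.NewCheckpointReader. A minimal sketch, with a placeholder checkpoint path; as above, the total includes optimizer slot variables:

import numpy as np
import tensorflow as tf

ckpt_path = "/path/to/model_dir/model.ckpt-10000"  # placeholder path

reader = tf.train.NewCheckpointReader(ckpt_path)
shape_map = reader.get_variable_to_shape_map()  # {variable name: shape as a list}

# Product of each shape gives the element count per variable.
total = sum(int(np.prod(shape)) for shape in shape_map.values())
print("Checkpoint stores {} parameters".format(total))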
