简体   繁体   English

Tensorflow对象检测API:如何禁用从检查点加载

[英]Tensorflow object detection API: How to disable loading from checkpoint

I have created a custom variation of MobileNetV2 feature extractor architecture, by changing the expansion_size from 6 to 4 in research/slim/nets/mobilenet/mobilenet_v2.py of tensorflow/models repo. 我创建MobileNetV2特征提取架构的自定义变化,通过改变expansion_size从6比4 research/slim/nets/mobilenet/mobilenet_v2.pytensorflow/models回购。

I want to be able to train the SSD + Mobilenet_v2 (with this change) model with model_main.py script as described at Object Detection API's running_locally tutorial . 我希望能够使用model_main.py脚本训练SSD + Mobilenet_v2(具有此更改)模型,如对象检测API的running_locally教程中所述

When I do so I see the following error, which makes sense: 当我这样做时,我看到以下错误,这是有道理的:

`InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

To address this: 要解决此问题:

  1. I removed the finetune_checkpoint specification from my pipeline.config . 我从pipeline.config删除了finetune_checkpoint规范。
  2. I changed load_pretrained=True to load_pretrained=False in object_detection/model_hparams.py . 我在object_detection/model_hparams.py load_pretrained=True更改为load_pretrained=False
  3. I added --hparams_overrides='load_pretrained=false' as a command line input argument to model_main.py . 我添加--hparams_overrides='load_pretrained=false'作为命令行输入参数model_main.py

Despite of these, I still see the same error. 尽管有这些,我仍然看到相同的错误。

Why is tensorflow still trying to restore a checkpoint. 为什么tensorflow仍在尝试还原检查点。 How can I make it not do so? 我该如何做呢?

Found the solution myself. 自己找到解决方案。 It turns out that even though I had made arrangements for it to not restore checkpoint from my pipeline configuration file, it turns out that the internal tf.Estimator object can still use a checkpoint from the model_dir specified; 事实证明,即使我已安排它不从管道配置文件中恢复检查点,但事实证明内部tf.Estimator对象仍可以使用指定的model_dir的检查点; even though the primary use of model_dir is as an output directory, where output checkpoints are written to. 即使model_dir的主要用途是将输出检查点写入其中的输出目录。

I found this information in the official documentation for tf.Estimator . 我在tf.Estimator官方文档中找到了此信息。 Here's the relevant excerpt for reference: 以下是相关摘录供参考:

`model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model . `model_dir:用于保存模型参数,图形等的目录这也可用于将检查点从目录加载到估计器中,以继续训练先前保存的模型 If PathLike object, the path will be resolved. 如果是PathLike对象,则路径将被解析。 If None, the model_dir in config will be used if set. 如果设置为None,则使用config中的model_dir。 If both are set, they must be same. 如果两者都设置,则必须相同。 If both are None, a temporary directory will be used. 如果两者都为None,则将使用一个临时目录。

I had an old checkpoint sitting in my original model_dir which was architecturally incompatible with my custom model. 我原来的model_dir有一个旧的检查点,该检查点在架构上与我的自定义模型不兼容。 Hence I was seeing the error. 因此,我看到了错误。 To resolve it, I simply changed my model_dir to point to a new empty directory and it worked. 为了解决这个问题,我只需将model_dir更改为指向一个新的空目录即可。 I hope that helps someone with a similar problem. 我希望这对遇到类似问题的人有所帮助。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 Tensorflow 2.0 Object 检测训练错误 - 加载检查点时出错 - Tensorflow 2.0 Object Detection Training Error - Error with loading checkpoint 张量流对象检测从现有检查点微调模型 - tensorflow object detection Fine-tuning a model from an existing checkpoint 如何禁用自动检查点加载 - How to disable automatic checkpoint loading Tensorflow-GPU 对象检测 API 在第一次保存检查点后卡住 - Tensorflow-GPU Object Detection API gets stuck after first saved checkpoint 无法加载预训练的 model 检查点与 TensorFlow Object 检测 ZDB974238714CA8DE634A7ACE1 - Unable to load pre-trained model checkpoint with TensorFlow Object Detection API 了解 Tensorflow 对象检测 API,检查点的 kwargs class,什么是`_base_tower_layers_for_heads`? - Understanding Tensorflow Object-Detection API, kwargs for Checkpoint class, what is `_base_tower_layers_for_heads`? 如何在TensorFlow对象检测API中从头开始训练? - How to train from scratch in TensorFlow object detection API? 从 TensorFlow 对象检测 API 打印对象 - Print Objects from TensorFlow Object Detection API 如何使用Tensorflow对象检测API提高对象检测的精度? - How to improve precision of object detection using tensorflow object detection API? 如何计算 Tensorflow Object Detection API 中的对象 - How to count objects in Tensorflow Object Detection API
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM