[英]PyTorch Detecto Model: tensor incompatibiliy in predicition for a pretrained model
嘗試訓練一個非常簡單的 model 並使用 pytorch 檢測器的以下代碼進行圖像預測:
from detecto import core, utils, visualize
dataset = core.Dataset('images/')
model = core.Model(['rect'])
model.fit(dataset)
modelName = 'model_weights_simpleRect.pth'
model.save(modelName)
image = utils.read_image('simple_image_to_test.jpg')
predictions = model.predict(image)
這導致以下 output:
Epoch 1 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:12<00:00, 1.56it/s]
Epoch 2 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.80it/s]
Epoch 3 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.80it/s]
Epoch 4 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.79it/s]
Epoch 5 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.79it/s]
Epoch 6 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.80it/s]
Epoch 7 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.78it/s]
Epoch 8 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.80it/s]
Epoch 9 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.78it/s]
Epoch 10 of 10
Begin iterating over training dataset
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.80it/s]
Traceback (most recent call last):
File "train_simpleRect_and_predict.py", line 15, in <module>
predictions = model.predict(image)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/detecto/core.py", line 338, in predict
preds = self._get_raw_predictions(images)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/detecto/core.py", line 294, in _get_raw_predictions
preds = self._model(images)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py", line 52, in forward
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torchvision/models/detection/roi_heads.py", line 550, in forward
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torchvision/models/detection/roi_heads.py", line 474, in postprocess_detections
pred_boxes = self.box_coder.decode(box_regression, proposals)
File "/home/std/anaconda3/envs/dri/lib/python3.7/site-packages/torchvision/models/detection/_utils.py", line 168, in decode
rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
如何獲得有關 model 維度的更多詳細信息,在 model 中,張量不兼容發生的確切位置以及如何修復它?
添加。 信息:我對其他數據使用了相同的代碼並且它有效。
謝謝!
問題是 xml 描述文件中的圖像尺寸錯誤,對應於每個圖像。
我修復了 xml 文件,並且錯誤不再發生。
另一個問題 - 引起了類似的張量維度錯誤,是由以下語句引起的:
model = Model.load(modelName, ['rect'])
正確的版本是:
model = Model()
model.load(modelName, ['rect'])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.