![](/img/trans.png)
[英]Tensorflow Keras: Dimension/Shape Error when running model.fit
[英]NotImplementedError when running model.fit in Tensorflow Keras
我目前正在研究 Aeriel 視頻上的人類行為識別。 我正在使用這個數據集。 您可以看到視頻和標簽文件。 我正在構建一個 SSD model 來訓練數據..
使用 model.fit 時出現錯誤。
我認為主要問題出在 DataGenerator Class 雖然我無法解決錯誤。 代碼如下
import numpy as np
import cv2
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import os
import json
class DataGenerator(tf.keras.utils.Sequence):
def __init__(self, img_dir, ann_dir,
batch_size=32, dim=(300,300),
shuffle=True):
self.idx_to_name = ["None", '"Handshaking"\n', '"Hugging"\n', '"Reading"\n', '"Drinking"\n', '"Pushing/Pulling"\n', '"Carrying"\n', '"Calling"\n', '"Running"\n', '"Walking"\n', '"Lying"\n', '"Sitting"\n', '"Standing"\n']
self.name_to_idx = dict([(v, k) for k, v in enumerate(self.idx_to_name)])
self.img_dir = img_dir
self.ann_dir = ann_dir
# self.frame = frame
self.batch_size = batch_size
self.dim = dim
def _get_annotation(self, file, j):
frame_map = dict()
with open(file, 'r') as fp:
line = fp.readline()
while line:
line_split = line.split(' ')
frame_id = int(line_split[5])
if line_split[10] is not None:
label = line_split[10]
else:
label = "None"
val = (int(line_split[0]), list(map(int, line_split[1:5])), list(map(int, line_split[6:8])), line_split[10])
if frame_id not in frame_map:
frame_map[frame_id] = [val]
else:
frame_map[frame_id].append(val)
line = fp.readline()
for obj in frame_map[int(j)]:
xmin = float(obj[1][0]) / 3840.0
ymin = float(obj[1][1]) / 2160.0
xmax = float(obj[1][2]) / 3840.0
ymax = float(obj[1][3]) / 2160.0
name = obj[3]
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.name_to_idx[name] + 1)
return np.array(boxes, dtype=np.float32), np.array(labels, dtype=np.int64)
def __getitem__(self, index):
start_index = index * self.batch_size
x_train = []
y_train = []
i = start_index - 1
while len(x_train) < self.batch_size:
try:
for i in os.listdir(self.img_dir):
for j in os.listdir(self.img_dir + '/' + i):
img = cv2.imread(self.img_dir + '/' +i + '/' + j)
img = cv2.resize(img,(320,240))
img = np.array(img, dtype = np.float32)
img = img / 255.0
boxes, labels = self._get_annotation(self.ann_dir + '/' + i + '.txt', int(j[:-4]))
x_train.append(img)
y_train.append(boxes)
i += 1
except Exception as err:
print(err)
continue
x_train = np.array(x_train)
y_train = np.array(y_train)
return x_train, y_train
train_data = DataGenerator("/content/okutama_imgs", "/content/okutama_labels", batch_size=4)
model.compile(optimizer=optimizer, loss= SSD_loss, metrics=['accuracy'])
model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)
運行 model.fit 時會出現此錯誤。 我不明白為什么會這樣,如果您需要更多信息,我很樂意提供給您。 這是錯誤
NotImplementedError Traceback (most recent call last)
<ipython-input-64-7c278c6b3232> in <module>()
----> 1 model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)
4 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/data_utils.py in __len__(self)
456 The number of batches in the Sequence.
457 """
--> 458 raise NotImplementedError
459
460 def on_epoch_end(self):
NotImplementedError:
此錯誤表明您沒有在繼承的 class 中實現一些需要的 function,在本例中為__len__
function。
像這樣將__len__
function 添加到DataGenerator
class 中:
def __len__(self):
#len(self.x) is the length of your input features
return math.ceil(len(self.x) / self.batch_size)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.