
NotImplementedError when running model.fit in Tensorflow Keras

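I am currently working on human action recognition in aerial videos. I am using this dataset, and you can see the video label files. I am building an SSD model to train on the data.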

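I get an error when calling model.fit.

I think the main problem is in the DataGenerator class, although I have not been able to resolve the error. The code is below.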

import numpy as np
import cv2
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import os
import json


class DataGenerator(tf.keras.utils.Sequence):
   def __init__(self, img_dir, ann_dir, 
                batch_size=32, dim=(300,300),
                shuffle=True):
       self.idx_to_name = ["None", '"Handshaking"\n', '"Hugging"\n', '"Reading"\n', '"Drinking"\n', '"Pushing/Pulling"\n', '"Carrying"\n', '"Calling"\n', '"Running"\n', '"Walking"\n', '"Lying"\n', '"Sitting"\n', '"Standing"\n']
       self.name_to_idx = dict([(v, k) for k, v in enumerate(self.idx_to_name)])
       self.img_dir = img_dir
       self.ann_dir = ann_dir
       # self.frame = frame
       self.batch_size = batch_size
       self.dim = dim



   def _get_annotation(self, file, j):
       # Map each frame id to the list of annotated objects in that frame.
       frame_map = dict()
       with open(file, 'r') as fp:
           # Read every line while the file is still open.
           for line in fp:
               line_split = line.split(' ')
               frame_id = int(line_split[5])
               if line_split[10] is not None:
                   label = line_split[10]
               else:
                   label = "None"
               val = (int(line_split[0]), list(map(int, line_split[1:5])), list(map(int, line_split[6:8])), line_split[10])
               if frame_id not in frame_map:
                   frame_map[frame_id] = [val]
               else:
                   frame_map[frame_id].append(val)

       # Collect normalized boxes and label indices for frame j.
       boxes = []
       labels = []
       for obj in frame_map[int(j)]:
           xmin = float(obj[1][0]) / 3840.0
           ymin = float(obj[1][1]) / 2160.0
           xmax = float(obj[1][2]) / 3840.0
           ymax = float(obj[1][3]) / 2160.0
           name = obj[3]
           boxes.append([xmin, ymin, xmax, ymax])

           labels.append(self.name_to_idx[name] + 1)

       return np.array(boxes, dtype=np.float32), np.array(labels, dtype=np.int64)

   def __getitem__(self, index):
       start_index = index * self.batch_size
       x_train = []
       y_train = []
       i = start_index - 1
       while len(x_train) < self.batch_size:
           try:
               for i in os.listdir(self.img_dir):
                 for j in os.listdir(self.img_dir + '/' + i):
                   img = cv2.imread(self.img_dir + '/' +i + '/' + j)
                   img = cv2.resize(img,(320,240))
                   img = np.array(img, dtype = np.float32)
                   img = img / 255.0
                   boxes, labels = self._get_annotation(self.ann_dir + '/' + i + '.txt', int(j[:-4]))
                   x_train.append(img)
                   y_train.append(boxes)
                   i += 1

           except Exception as err:
               print(err)
               continue
           
       x_train = np.array(x_train)
       y_train = np.array(y_train)
 
       return x_train, y_train
train_data = DataGenerator("/content/okutama_imgs", "/content/okutama_labels", batch_size=4)
model.compile(optimizer=optimizer, loss= SSD_loss, metrics=['accuracy'])
model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)

This error occurs when running model.fit. I do not understand why it happens; if you need more information, I am happy to provide it. Here is the error:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-64-7c278c6b3232> in <module>()
----> 1 model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)

4 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/data_utils.py in __len__(self)
    456         The number of batches in the Sequence.
    457     """
--> 458     raise NotImplementedError
    459 
    460   def on_epoch_end(self):

NotImplementedError: 

This error indicates that the inheriting class does not implement a function that the base class requires, in this case the __len__ function.
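To illustrate the mechanism, here is a minimal sketch in plain Python. The names BaseSequence, NoLen, WithLen, and try_len are hypothetical and exist only for this illustration; BaseSequence merely mimics the behaviour visible in the traceback, where the base class's __len__ raises NotImplementedError.

```python
class BaseSequence:
    """Hypothetical stand-in for a base class with unimplemented methods."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class NoLen(BaseSequence):
    """Overrides __getitem__ only, like the DataGenerator in the question."""

    def __getitem__(self, index):
        return index


class WithLen(NoLen):
    """Also overrides __len__, so len() works."""

    def __len__(self):
        return 3


def try_len(seq):
    """Return len(seq), or the name of the exception that len() raises."""
    try:
        return len(seq)
    except NotImplementedError as err:
        return type(err).__name__


print(try_len(NoLen()))    # falls through to the base class __len__
print(try_len(WithLen()))  # uses the overriding __len__
```

Calling len() on the subclass without its own __len__ falls through to the base class and raises, which is what model.fit runs into when it asks the generator how many batches it has.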

Add a __len__ function to the DataGenerator class like this:

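import math

def __len__(self):
    # Number of batches per epoch. self.x stands for your input samples;
    # the DataGenerator above does not define it, so substitute your own sample list.
    return math.ceil(len(self.x) / self.batch_size)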
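As a rough sketch of how __len__ and __getitem__ fit together, the following generator serves batches from in-memory arrays. The class name BatchedSamples and its x and y attributes are assumptions made for this example and are not part of the question's code; the try/except around the import is only there so the sketch also runs where TensorFlow is not installed.

```python
import math

import numpy as np

try:
    from tensorflow.keras.utils import Sequence
except ImportError:
    Sequence = object  # fallback so the sketch still runs without TensorFlow


class BatchedSamples(Sequence):
    """Hypothetical generator that serves (x, y) batches from arrays."""

    def __init__(self, x, y, batch_size=32):
        super().__init__()
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, index):
        # Return only batch number `index` rather than rescanning all data.
        start = index * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]


gen = BatchedSamples(np.zeros((10, 4)), np.arange(10), batch_size=4)
print(len(gen))         # number of batches for 10 samples with batch_size=4
print(gen[2][0].shape)  # shape of the final, shorter batch
```

For the DataGenerator in the question, the equivalent of self.x would presumably be a list of image and annotation paths collected once in __init__, with __getitem__ loading only the files that belong to the requested batch.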

