
TypeError: fit_generator() missing 1 required positional argument: 'generator'

I am trying to train a CNN to detect whether an image is a deepfake, but when I run the code I keep getting this error: TypeError: fit_generator() missing 1 required positional argument: 'generator'. How do I get rid of this error? Is there an issue with my code? I'm also not sure whether the Classifier class is necessary, so I've included it but commented it out.

My code in full:

import tensorflow as tf 
config = tf.compat.v1.ConfigProto() 
config.gpu_options.allow_growth = True 
sess = tf.compat.v1.Session(config=config)

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, Concatenate, LeakyReLU
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model

from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.models import Model

# Height and width refer to the size of the image
# Channels refers to the amount of color channels (red, green, blue)

image_dimensions = {'height':256, 'width':256, 'channels':3}

# Create a Classifier class

#class Classifier():
#
#    def __init__():
#        self.model = 0
#
#    def predict(self, x):
#        return self.model.predict(x)
#
#    def fit(self, x, y):
#        return self.model.train_on_batch(x, y)
#
#    def get_accuracy(self, x, y):
#        return self.model.test_on_batch(x, y)
#
#    def load(self, path):
#        self.model.load_weights(path)

class Meso4(Model):
    def __init__(self, learning_rate = 0.0001):
        self.model = self.init_model()
        optimizer = Adam(lr = learning_rate)
        self.model.compile(optimizer = optimizer,
                           loss = 'mean_squared_error',
                           metrics = ['accuracy'])   
    
        
    
    def init_model(self):
        x = Input(shape = (image_dimensions['height'],
                           image_dimensions['width'],
                           image_dimensions['channels']))
        
        x1 = Conv2D(8, (3, 3), padding='same', activation = 'relu')(x)
        x1 = BatchNormalization()(x1)
        x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)
        
        x2 = Conv2D(8, (5, 5), padding='same', activation = 'relu')(x1)
        x2 = BatchNormalization()(x2)
        x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)
        
        x3 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x2)
        x3 = BatchNormalization()(x3)
        x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)
        
        x4 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x3)
        x4 = BatchNormalization()(x4)
        x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)
        
        y = Flatten()(x4)
        y = Dropout(0.5)
        y = Dense(16)
        y = LeakyReLU(alpha=0.1)
        y = Dropout(0.5)
        y = Dense(1, activation = 'sigmoid')
        
        return Model(inputs = x, outputs = y)

bat_size = 64
input_size = 256

# initializing a train datagenerator
train_datagen = ImageDataGenerator(rescale=1./255)

# initializing a test datagenerator
test_datagen = ImageDataGenerator(rescale=1./255)

# preprocessing for training set
train_set = train_datagen.flow_from_directory(
                            'C:\\Users\\Kevin\\Desktop\\Train', # train data directory
                            target_size=(input_size, input_size), 
                            batch_size=bat_size,
                            class_mode='categorical',
                            color_mode= 'rgb'
                                            )

# preprocessing for test set
test_set = test_datagen.flow_from_directory(
                            'C:\\Users\\Kevin\\Desktop\\Test', # test data directory
                            target_size=(input_size, input_size),
                            batch_size=bat_size,
                            shuffle=False,
                            class_mode='categorical',
                            color_mode= 'rgb'
                                            )
filepath = "FYP.hdf5"
checkpoint = ModelCheckpoint(
                            filepath,
                            monitor='val_acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max'
                            )
Meso4.fit_generator(
                                train_set,
                                steps_per_epoch=1400//bat_size + 1,
                                epochs=25,
                                callbacks=[checkpoint],
                                validation_data=test_set,
                                validation_steps=600 //bat_size + 1
                                )

#ERROR
TypeError                                 Traceback (most recent call last)
<ipython-input-9-00d0b295f968> in <module>
      5                                 callbacks=[checkpoint],
      6                                 validation_data=test_set,
----> 7                                 validation_steps=600 //bat_size + 1
      8                                 )

~\Anaconda3\envs\Tf\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

TypeError: fit_generator() missing 1 required positional argument: 'generator'

I can see three or four mistakes:

For subclassing in Keras:

  • You need to call super(YourClass, self).__init__() (or simply super().__init__() in Python 3) at the start of __init__; Meso4.__init__ never does this.
  • You define the model's forward pass inside a call method, creating the layers in __init__.
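A minimal sketch of what the two points above look like for this model, assuming TensorFlow 2.x and importing everything from tensorflow.keras (the question mixes tensorflow.keras with standalone keras, and the later "from keras.models import Model" line shadows the earlier import). The layer sizes are copied from init_model; the conv_block helper and the attribute names are hypothetical choices for this sketch, not part of the original code.

```python
import tensorflow as tf
from tensorflow.keras import layers


def conv_block(filters, kernel, pool):
    # Hypothetical helper: one Conv2D -> BatchNormalization -> MaxPooling2D block.
    return tf.keras.Sequential([
        layers.Conv2D(filters, kernel, padding="same", activation="relu"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=pool, padding="same"),
    ])


class Meso4(tf.keras.Model):
    """Sketch of Meso4 as a subclassed model: layers built in __init__, wired in call()."""

    def __init__(self, **kwargs):
        # Required for subclassing: initialise the underlying tf.keras.Model first.
        super().__init__(**kwargs)
        # (filters, kernel size, pool size) taken from the question's init_model.
        self.blocks = [conv_block(f, k, p)
                       for f, k, p in [(8, 3, 2), (8, 5, 2), (16, 5, 2), (16, 5, 4)]]
        self.flatten = layers.Flatten()
        self.drop1 = layers.Dropout(0.5)
        self.dense1 = layers.Dense(16)
        self.act = layers.LeakyReLU(0.1)
        self.drop2 = layers.Dropout(0.5)
        self.classifier = layers.Dense(1, activation="sigmoid")

    def call(self, inputs, training=False):
        x = inputs
        for block in self.blocks:
            # training matters for BatchNormalization inside each block.
            x = block(x, training=training)
        x = self.flatten(x)
        x = self.drop1(x, training=training)
        x = self.act(self.dense1(x))
        x = self.drop2(x, training=training)
        return self.classifier(x)
```

With the class written this way, training goes through an instance rather than the class: model = Meso4(), then model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='mean_squared_error', metrics=['accuracy']) and model.fit(train_set, epochs=25, validation_data=test_set, callbacks=[checkpoint]). In TensorFlow 2.x fit_generator is deprecated because fit accepts the directory iterators directly. One more thing worth checking: a single sigmoid output expects one label per image, so class_mode='binary' in flow_from_directory probably matches it better than 'categorical'.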

Check this link to learn more about Keras subclassing.

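Also, in the y part of init_model you stop using the functional syntax:

y = Flatten()(x4)
y = Dropout(0.5)
y = Dense(16)

It should be:

y = Dropout(0.5)(y)
y = Dense(16)(y)

The same applies to the LeakyReLU, second Dropout and final Dense lines: each layer has to be called on the previous tensor. Otherwise y ends up holding a layer object instead of a tensor, and Model(inputs = x, outputs = y) cannot be built from it.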
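If you prefer to keep the functional style instead of subclassing, a sketch of init_model rewritten as a plain function could look like this. It assumes TensorFlow 2.x; build_meso4 is a hypothetical name, and the loop simply repeats the four conv blocks from the question with the same filter, kernel and pool sizes. Every layer is called on the previous tensor, so y is always a tensor.

```python
import tensorflow as tf
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                                     Flatten, Input, LeakyReLU, MaxPooling2D)
from tensorflow.keras.models import Model


def build_meso4(height=256, width=256, channels=3):
    # Hypothetical replacement for Meso4.init_model using only the functional API.
    x = Input(shape=(height, width, channels))
    y = x
    # (filters, kernel size, pool size) for the four conv blocks in the question.
    for filters, kernel, pool in [(8, 3, 2), (8, 5, 2), (16, 5, 2), (16, 5, 4)]:
        y = Conv2D(filters, kernel, padding="same", activation="relu")(y)
        y = BatchNormalization()(y)
        y = MaxPooling2D(pool_size=pool, padding="same")(y)
    # Classifier head: note the trailing (y) on every line.
    y = Flatten()(y)
    y = Dropout(0.5)(y)
    y = Dense(16)(y)
    y = LeakyReLU(0.1)(y)
    y = Dropout(0.5)(y)
    y = Dense(1, activation="sigmoid")(y)
    return Model(inputs=x, outputs=y)
```

Because this returns an ordinary Model, no Meso4 class is needed at all: create model = build_meso4(), then compile and fit that instance.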

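Finally, you don't call the method on the class directly; instantiate a new object first (model = Meso4()) and call the method on that. Calling Meso4.fit_generator(train_set, ...) on the class makes Python bind train_set to self, so the generator parameter is left empty, which is exactly the error in your traceback.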
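The error itself is plain Python behaviour and can be reproduced without Keras. In this sketch Trainer is a hypothetical stand-in class whose fit_generator only mirrors the parameter name from the traceback; calling it on the class fills self with the first argument and leaves generator missing.

```python
class Trainer:
    """Hypothetical stand-in for a Keras model; not the real Keras API."""

    def fit_generator(self, generator, epochs=1):
        # Pretend to train: count the batches the generator yields, per epoch.
        return sum(1 for _ in generator) * epochs


batches = [[1, 2], [3, 4]]

# Wrong: called on the class, so `batches` is bound to `self`
# and nothing is left for the `generator` parameter.
message = None
try:
    Trainer.fit_generator(batches, epochs=1)
except TypeError as err:
    message = str(err)
print(message)

# Right: create an instance first; Python then passes it as `self` automatically.
trainer = Trainer()
print(trainer.fit_generator(batches, epochs=1))
```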


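Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site or the original source. For any questions, contact: yoyou2525@163.com.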

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM