[英]one class classification with keras
我正在尝试建立一个模型来检测输入图像是否为某些图像(例如,狗是否为)。 我正在用keras编码,但准确性很差。 您有什么想法可以正确调整吗? 还是我应该使用除keras之外的其他工具来解决一类分类问题? 提前非常感谢您。
这是我到目前为止编写的代码和输出。
train_dir = './path/to/train_dir'
vali_dir = './path/to/validation_dir'
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=False)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')
vali_datagen = ImageDataGenerator(rescale=1./255)
vali_generator = vali_datagen.flow_from_directory(
vali_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')
model = Sequential()
model.add(Conv2D(16, 3, activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPool2D(pool_size=2))
model.add(Conv2D(32, 3, activation='relu'))
model.add(MaxPool2D(pool_size=2))
model.add(Conv2D(64, 3, activation='relu'))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(
loss='binary_crossentropy',
optimizer=RMSprop(lr=0.003),
metrics=['acc']
)
history = model.fit_generator(
train_generator,
steps_per_epoch=100,
epochs=8,
verbose=2,
validation_data=vali_generator,
validation_steps=20
)
输出:
Found 3379 images belonging to 2 classes.
Found 607 images belonging to 2 classes.
Epoch 1/8
- 136s - loss: 7.6617 - acc: 0.5158 - val_loss: 10.5220 - val_acc: 0.3400
Epoch 2/8
- 124s - loss: 7.7837 - acc: 0.5118 - val_loss: 10.5220 - val_acc: 0.3400
.......and this is just terrible.
看起来类标签有问题-它们与数据正确相关吗? 您可以检查它或发布ImageDataGenerator代码
即使从第一个时期开始,火车的准确性和验证的准确性仍存在很大差异。 在我看来,这似乎是一个过度训练的问题。 因此,您应该为网络提供更多正规化。 就像卷积层内部的更多Dropoutlayers或kernel_regularizer
一样。
我尝试更改和调整参数和训练数据,但没有得到理想的结果。 我遇到了一个使用Isolation forest
类分类。 这就是所谓的新颖性检测,在我使用它之后,它的表现非常出色。 感谢那些在评论中建议我的人,很抱歉我自己回答。
如果输入要素之间不存在依赖关系,则隔离林是用于异常检测的良好算法。 但是,如果您输入的是时间序列信号或图像,则最好使用RNN或CNN之类的方法。
我最近遇到了一个名为CNN类的异常检测模型。 如果您输入的是图像或时间序列信号,则效果很好。 这是他们的github的链接:
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.