
Keras prints no output, high memory and CPU usage, and GPU is not used when using TensorBoard callback

I have a weird situation in Keras and it really frustrates me. I am trying to train a CNN using a pretrained Inception network with an additional convolution layer, global average pooling and dense layers. I am using an ImageDataGenerator to load the data.

The data generator works fine; I have tested it. The model also compiles fine. But when I run fit_generator, no output is printed, the CPU sits at 100%, and memory slowly fills up until it overflows. And although I have a GPU and have used it with TensorFlow (the backend here) many times, Keras ignores it completely.

Suspecting that the batch size might be the problem, I set it to 1, but that did not solve the issue. The images are 299x299, which is not that big anyway.
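For scale, a rough calculation of the input batch size, assuming float32 pixels (4 bytes per value, an assumption about the dtype the generator feeds the model): one 299x299 RGB image is about 1 MB, so at batch_size=1 the input tensor by itself is small.

```python
# Rough size of one input batch, assuming float32 pixels (4 bytes each).
def batch_bytes(batch_size, height=299, width=299, channels=3, bytes_per_value=4):
    """Return the number of bytes needed to hold one input batch."""
    return batch_size * height * width * channels * bytes_per_value

one_image = batch_bytes(1)
print(one_image)                     # 1072812 bytes
print(round(one_image / 2**20, 2))   # about 1.02 MiB
```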

I am posting the code below for reference, although it seems to me that nothing is wrong with it:

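# Standalone Keras 2.x imports used by this script.
from keras.applications import InceptionV3
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.layers import Activation, Conv2D, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

# train_folder, validation_folder and test_folder are the image directories
# (one subdirectory per class); they are defined elsewhere.


def get_datagen():
    # Augmenting generator: random rotations, shifts and horizontal flips.
    return ImageDataGenerator(rotation_range=30,
                              width_shift_range=0.2,
                              height_shift_range=0.2,
                              horizontal_flip=True,
                              fill_mode='nearest')

# Setup and compile the model: pretrained InceptionV3 base without its classifier.
model = InceptionV3(include_top=False, input_shape=(None, None, 3))

# Set the base model layers to be untrainable.
for layer in model.layers:
    layer.trainable = False

# New head: 120 conv filters, averaged into one score per class.
x = model.output
x = Conv2D(120, 5, activation='relu')(x)
x = GlobalAveragePooling2D()(x)
predictions = Activation('softmax')(x)

model_final = Model(inputs=model.inputs, outputs=predictions)

model_final.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Define the dataflow.
train_gen = get_datagen()
val_test_gen = get_datagen()

train_data = train_gen.flow_from_directory(train_folder, target_size=(299, 299), batch_size=1)
val_data = val_test_gen.flow_from_directory(validation_folder, target_size=(299, 299), batch_size=1)
test_data = val_test_gen.flow_from_directory(test_folder, target_size=(299, 299), batch_size=1)

train_size = train_data.n
val_size = val_data.n
test_size = test_data.n

# Define callbacks.
# Keras < 2.3 logs metrics=['accuracy'] as 'val_acc'; 'val_accuracy' is the key from Keras 2.3 on.
# ModelCheckpoint's filepath must name a file, not a directory ('weights.hdf5' is illustrative).
model_checkpoint = ModelCheckpoint('../models/dbc1/weights.hdf5', monitor='val_accuracy', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, verbose=1, mode='max')
# Weight and gradient histograms every epoch (the callback singled out in the EDIT below).
tensorboard = TensorBoard(log_dir='../log/dbc1', histogram_freq=1, write_grads=True)

# steps_per_epoch=1 runs only a single batch per epoch.
model_final.fit_generator(train_data, steps_per_epoch=1, epochs=100,
                          callbacks=[model_checkpoint, early_stopping, tensorboard],
                          validation_data=val_data, verbose=1)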
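The code computes train_size, val_size and test_size but then passes steps_per_epoch=1, so each epoch sees one batch. If a full pass over the data per epoch is intended, the step counts are usually derived from the sample counts. A minimal sketch, where steps_for is a hypothetical helper and the sample counts are made up:

```python
import math


def steps_for(num_samples, batch_size):
    """Batches needed to cover num_samples once (the last batch may be partial)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(num_samples / batch_size)


# Hypothetical sample counts; in the question they come from train_data.n etc.
print(steps_for(1000, 1))   # 1000
print(steps_for(1000, 32))  # 32
```

The results would then be passed to fit_generator as steps_per_epoch=steps_for(train_size, batch_size) and validation_steps=steps_for(val_size, batch_size).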

EDIT

It seems the TensorBoard callback was the problem here. When I remove it, everything works. Does anyone know why this happens?

ANSWER

There seems to be a problem (possibly related to keras#3358) when using histogram_freq=1 under certain conditions.

You could try setting histogram_freq=0 and submitting an issue on the Keras repository. You won't get the gradient histograms, but at least you will be able to train:

model.fit(...,
          callbacks=[
              # histogram_freq=0 (also the default) disables weight and gradient histograms.
              TensorBoard(log_dir='./logs/', batch_size=batch_size, histogram_freq=0),
              ...
          ])
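In Keras 2.x, write_grads only has an effect when histogram_freq is greater than 0, so setting histogram_freq=0 turns off both the weight and the gradient histograms. A minimal sketch that keeps the two settings in sync; tensorboard_kwargs is a hypothetical helper, and the keys are the Keras 2.x TensorBoard argument names used in this thread:

```python
def tensorboard_kwargs(log_dir, batch_size, with_histograms=False):
    """Keyword arguments for keras.callbacks.TensorBoard (Keras 2.x).

    With with_histograms=False, histogram_freq is 0, so no weight or
    gradient histograms are computed and write_grads is left off.
    """
    return {
        "log_dir": log_dir,
        "batch_size": batch_size,
        "histogram_freq": 1 if with_histograms else 0,
        "write_grads": with_histograms,
    }


kwargs = tensorboard_kwargs("./logs/", batch_size=1)
print(kwargs["histogram_freq"], kwargs["write_grads"])  # 0 False
# With Keras installed: TensorBoard(**kwargs)
```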

I noticed that this problem doesn't happen with every pretrained model. If using InceptionV3 is not a requirement, I recommend switching to another model. So far, I have found that the following code (adapted from yours, using VGG19) works with keras==2.1.2 and tensorflow==1.4.1:

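from keras.applications import VGG19
from keras.applications.vgg19 import preprocess_input
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.layers import Activation, Conv2D, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

# train_folder, validation_folder and test_folder are the same directories as in the question.

input_shape = (224, 224, 3)  # VGG19's native input size
batch_size = 1

# Frozen VGG19 base without its classifier.
model = VGG19(include_top=False, input_shape=input_shape)
for layer in model.layers:
    layer.trainable = False

# New head: 2 conv filters, averaged into one score per class (2 classes).
y = model.output
y = Conv2D(2, 5, activation='relu')(y)
y = GlobalAveragePooling2D()(y)
y = Activation('softmax')(y)

model = Model(inputs=model.inputs, outputs=y)
model.compile('adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Same augmentation as the question, plus VGG19's own input preprocessing.
g = ImageDataGenerator(rotation_range=30,
                       width_shift_range=0.2,
                       height_shift_range=0.2,
                       horizontal_flip=True,
                       preprocessing_function=preprocess_input)

train_data = g.flow_from_directory(train_folder,
                                   target_size=input_shape[:2],
                                   batch_size=batch_size)
val_data = g.flow_from_directory(validation_folder,
                                 target_size=input_shape[:2],
                                 batch_size=batch_size)
test_data = g.flow_from_directory(test_folder,
                                  target_size=input_shape[:2],
                                  batch_size=batch_size)

model.fit_generator(train_data, steps_per_epoch=1, epochs=100,
                    validation_data=val_data, verbose=1,
                    callbacks=[
                        # Keras 2.1.2 logs metrics=['accuracy'] as 'val_acc'
                        # ('val_accuracy' is only used from Keras 2.3 on).
                        ModelCheckpoint('./ckpt.hdf5',
                                        monitor='val_acc',
                                        verbose=1,
                                        save_best_only=True),
                        # Monitors val_loss by default.
                        EarlyStopping(patience=3, verbose=1),
                        TensorBoard(log_dir='./logs/',
                                    batch_size=batch_size,
                                    histogram_freq=1,
                                    write_grads=True)])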
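Both versions put a 5x5 Conv2D with the default padding='valid' (stride 1) on top of the frozen base, so the base's last feature map has to be at least 5x5. With include_top=False, InceptionV3 at 299x299 ends in an 8x8 map and VGG19 at 224x224 ends in a 7x7 map (224 / 32), so both fit. A quick check of the valid-convolution size formula; valid_conv_out is a hypothetical helper:

```python
def valid_conv_out(size, kernel, stride=1):
    """Spatial output size of a convolution with padding='valid'."""
    if size < kernel:
        raise ValueError("feature map is smaller than the kernel")
    return (size - kernel) // stride + 1


print(valid_conv_out(8, 5))  # 4  (InceptionV3 base at 299x299)
print(valid_conv_out(7, 5))  # 3  (VGG19 base at 224x224)
```

Because GlobalAveragePooling2D then reduces each filter's map to a single value, the Conv2D filter count (120 in the question, 2 here) is the number of classes the softmax outputs, and it should match the number of class subdirectories that flow_from_directory finds.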
