簡體   English   中英

Keras:進行超參數網格搜索時內存不足

[英]Keras: Out of memory when doing hyper parameter grid search

我正在運行多個嵌套循環來進行超參數網格搜索。 每個嵌套循環遍歷超級參數值列表,並且在最內層循環內部,每次使用生成器構建和評估Keras順序模型。 (我沒有做任何訓練,我只是隨機初始化,然后多次評估模型,然后檢索平均損失)。

我的問題是,在這個過程中,Keras似乎填滿了我的GPU內存,所以我最終得到了一個OOM錯誤。

在評估模型后,是否有人知道如何解決這個問題並釋放GPU內存?

在評估之后我根本不需要模型,我可以在內循環的下一次傳遞中構建一個新模型之前完全拋棄它。

我正在使用Tensorflow后端。

這是代碼,盡管其中大部分與一般問題無關。 該模型構建在第四個循環內,

for fsize in fsizes:

我想有關如何構建模型的細節並不重要,但無論如何都是這樣的:

model_losses = []
model_names = []

for activation in activations:
    for i in range(len(layer_structures)):
        for width in layer_widths[i]:
            for fsize in fsizes:

                model_name = "test_{}_struc-{}_width-{}_fsize-{}".format(activation,i,np.array_str(np.array(width)),fsize)
                model_names.append(model_name)
                print("Testing new model: ", model_name)

                #Structure for this network
                structure = layer_structures[i]

                row, col, ch = 80, 160, 3  # Input image format

                model = Sequential()

                model.add(Lambda(lambda x: x/127.5 - 1.,
                          input_shape=(row, col, ch),
                          output_shape=(row, col, ch)))

                for j in range(len(structure)):
                    if structure[j] == 'conv':
                        model.add(Convolution2D(width[j], fsize, fsize))
                        model.add(BatchNormalization(axis=3, momentum=0.99))
                        if activation == 'relu':
                            model.add(Activation('relu'))
                        if activation == 'elu':
                            model.add(ELU())
                            model.add(MaxPooling2D())
                    elif structure[j] == 'dense':
                        if structure[j-1] == 'dense':
                            model.add(Dense(width[j]))
                            model.add(BatchNormalization(axis=1, momentum=0.99))
                            if activation == 'relu':
                                model.add(Activation('relu'))
                            elif activation == 'elu':
                                model.add(ELU())
                        else:
                            model.add(Flatten())
                            model.add(Dense(width[j]))
                            model.add(BatchNormalization(axis=1, momentum=0.99))
                            if activation == 'relu':
                                model.add(Activation('relu'))
                            elif activation == 'elu':
                                model.add(ELU())

                model.add(Dense(1))

                average_loss = 0
                for k in range(5):
                    model.compile(optimizer="adam", loss="mse")
                    val_generator = generate_batch(X_val, y_val, resize=(160,80))
                    loss = model.evaluate_generator(val_generator, len(y_val))
                    average_loss += loss

                average_loss /= 5

                model_losses.append(average_loss)

                print("Average loss after 5 initializations: {:.3f}".format(average_loss))
                print()

如圖所示,使用的后端是Tensorflow。 使用Tensorflow后端,當前模型不會被破壞,因此您需要清除會話。

使用完模型之后只需:

if K.backend() == 'tensorflow':
    K.clear_session()

包括后端:

from keras import backend as K

您也可以使用sklearn包裝器進行網格搜索。 檢查這個例子: 這里 此外,對於更高級的超參數搜索,您可以使用hyperas

使用indraforyou給出的提示,我添加了代碼以清除我傳遞給GridSearchCV的函數內的TensorFlow會話,如下所示:

def create_model():
    # cleanup
    K.clear_session()

    inputs = Input(shape=(4096,))
    x = Dense(2048, activation='relu')(inputs)
    p = Dense(2, activation='sigmoid')(x)
    model = Model(input=inputs, outputs=p)
    model.compile(optimizer='SGD',
              loss='mse',
              metrics=['accuracy'])
    return model

然后我可以調用網格搜索:

model = KerasClassifier(build_fn=create_model)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=1)

它應該工作。

干杯!

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM