簡體   English   中英

如何將修復超參數作為 Keras-Tuner 的變量傳遞?

[英]How to pass fix hyperparameters as variables for Keras-Tuner?

我想使用 Keras 調諧器對 Keras model 進行超參數調整。

import tensorflow as tf
from tensorflow import keras
import keras_tuner as kt

def model_builder(hp):

  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

tuner.search(train_X, train_y, epochs=50)

到目前為止,一切都很好。 但是,我還想定義一些 model 參數(如輸入圖像尺寸)作為model_builder的輸入參數,我一無所知,該怎么做:

def model_builder(hp, img_dim1, img_dim2):

  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(img_dim1, img_dim2)))
...

tuner = kt.Hyperband(model_builder(img_dim1, img_dim2),
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

似乎不起作用。 如何將img_dim1, img_dim2饋送到hp以外的 model ?

一個簡單的解決方案是在 python 中使用“部分函數”,如下所示:

from functools import partial

#...

model_builder_ready = partial(model_builder, img_dim1 = value1, img_dim2 = value2)

tuner = kt.Hyperband(model_builder_ready,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

我想出的解決方案是創建一個 function ,它返回一個 function (可能是 partial 的作用),所以這應該是這樣的:

def model_builder(img_dim1, img_dim2):
    def func(hp):
        """
        Your original builder but here img_dim1 and img_dim2 exist in the scope so you can use them as parameter
        """
    return func

tuner = kt.Hyperband(model_builder(img_dim1, img_dim2),
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM