简体   繁体   English

Scikit-Learn 包装器和 RandomizedSearchCV:RuntimeError

[英]Scikit-Learn wrapper and RandomizedSearchCV: RuntimeError

I am reading the book我在看书

" Hands-On Machine Learning with Scikit-Learn, Keras, and Tensorflow: Concepts, Tools, and Techniques to Build Intelligent Systems " Scikit-Learn、Keras 和 Tensorflow 的动手机器学习:构建智能系统的概念、工具和技术

and in the Chapter 11 ( Introduction to ANN with Keras ) is explained that one can wrap a tensorflow model in scikit-learn to use some useful tools, like RandomizedSearchCV which is quite useful for random search of ANN hyperparameters (ANN structure, learning rate, activation functions, etc)在第 11 章( Keras 的 ANN 简介)中解释了可以将 tensorflow model 包装在 scikit-learn 中以使用一些有用的工具,例如RandomizedSearchCV ,它对于随机搜索 ANN 超参数(ANN 结构、学习率、激活函数)非常有用, ETC)

But I get a strange error at the end of the Randomized Search.但是我在随机搜索结束时遇到了一个奇怪的错误。 Specifically, after the random search, at the end of every combinations I get this:具体来说,在随机搜索之后,在每个组合的末尾我得到这个:

RuntimeError                              Traceback (most recent call last)
<ipython-input-35-094d1018c18c> in <module>()
     13 rnd_search_cv.fit(X_train, y_train, epochs=100, 
     14                   validation_data=(X_valid, y_valid),
---> 15                   callbacks=[keras.callbacks.EarlyStopping(patience=10)])

1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    734             # of the params are estimators as well.
    735             self.best_estimator_ = clone(clone(base_estimator).set_params(
--> 736                 **self.best_params_))
    737             refit_start_time = time.time()
    738             if y is not None:

/usr/local/lib/python3.6/dist-packages/sklearn/base.py in clone(estimator, safe)
     80             raise RuntimeError('Cannot clone object %s, as the constructor '
     81                                'either does not set or modifies parameter %s' %
---> 82                                (estimator, name))
     83     return new_object
     84 

RuntimeError: Cannot clone object <tensorflow.python.keras.wrappers.scikit_learn.KerasRegressor object at 0x7f16ce468fd0>, as the constructor either does not set or modifies parameter learning_rate

I followed every step in the chapter, namely:我遵循了本章中的每一步,即:

Function for model parameterization Function为model参数化

# build model given a set of parameters
input_shape = X_train[0].shape
X_new = X_test[:3]

def build_model(n_hidden=1, n_neurons=30, learning_rate=3e-3):
    model = keras.models.Sequential()
    model.add(keras.layers.InputLayer(input_shape=input_shape))
    for layer in range(n_hidden):
        model.add(keras.layers.Dense(n_neurons, activation="relu"))
    model.add(keras.layers.Dense(1))
    optimizer = keras.optimizers.SGD(lr=learning_rate)
    model.compile(optimizer=optimizer, loss="mse")
    return model

Scikit model wrapper Scikit model 包装器

keras_reg = keras.wrappers.scikit_learn.KerasRegressor(build_model)

I also tested the model, and it worked just fine我还测试了 model,它工作得很好

keras_reg.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid), 
              callbacks=[keras.callbacks.EarlyStopping(patience=10)])
mse_test = keras_reg.score(X_test, y_test)
y_pred = keras_reg.predict(X_new)

But when I used the RandomizedSearchCV但是当我使用 RandomizedSearchCV

# use RandomSearch (or grid search)

from scipy.stats import reciprocal
from sklearn.model_selection import RandomizedSearchCV

param_distribs = {
    "n_hidden": [0, 1, 2, 3],
    "n_neurons": np.arange(1, 100),
    "learning_rate": reciprocal(3e-4, 3e-2)
}

rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, n_iter=10, cv=3)
rnd_search_cv.fit(X_train, y_train, epochs=100, 
                  validation_data=(X_valid, y_valid),
                  callbacks=[keras.callbacks.EarlyStopping(patience=10)])

I get the above RuntimeError .我得到上面的RuntimeError

I am working on colab , with tensorflow 2.3.0我正在使用colab ,使用 tensorflow 2.3.0

import tensorflow as tf
tf.__version__

2.3.0

Does someone know why?有人知道为什么吗?

I've had the same issue, and it seems to arise from not assigning iterable values in the param_distribs dictionary (or at least, values that Scikit-Learn views as iterable).我遇到了同样的问题,这似乎是由于没有在 param_distribs 字典中分配可迭代的值(或者至少是 Scikit-Learn 认为可迭代的值)。 One way I've found to work around this is to replace these values with iterable equivalents:我发现解决这个问题的一种方法是用可迭代的等价物替换这些值:

param_distribs = {
"n_hidden": [0, 1, 2, 3],
"n_neurons": np.arange(1, 100).tolist(),
"learning_rate": np.arange(3e-4, 3e-2).tolist()
}

While this doesn't exactly reproduce Géron's code, it does seem to work!虽然这并不能完全重现 Géron 的代码,但它似乎确实有效!

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 在 scikit-learn 中使用 RandomizedSearchCV 对超参数进行条件调整 - Conditional tuning of hyperparameters with RandomizedSearchCV in scikit-learn Scikit-Learn 中的 RandomizedSearchCV 溢出错误 - OverflowError while RandomizedSearchCV in Scikit-Learn 如何使用 Scikit-Learn 的 RandomizedSearchCV 和 Tensorflow 的 ImageDataGenerator - How to use RandomizedSearchCV of Scikit-Learn with ImageDataGenerator of Tensorflow 在 scikit-learn 的 RandomizedSearchCV 中使用保留集进行验证? - Using hold-out-set for validation in RandomizedSearchCV in scikit-learn? 在scikit-learn中将RandomizedSearchCV(或GridSearcCV)与LeaveOneGroupOut交叉验证相结合 - Combining RandomizedSearchCV (or GridSearcCV) with LeaveOneGroupOut cross validation in scikit-learn Scikit学习-您可以在没有交叉验证的情况下运行RandomizedSearchCV吗? - Scikit-learn - Can you run RandomizedSearchCV without cross validation? 运行时错误:无法克隆对象:Scikit-Learn 自定义估算器 - RuntimeError: Cannot clone object: Scikit-Learn custom estimator scikit-learn:如何在由一个列表组成的嵌套列表上使用 RandomizedSearchCV? - scikit-learn: How to use RandomizedSearchCV on a nested list consisting of one list? 使用 RandomizedSearchCV (scikit-learn) 优化隐藏层和神经元的数量 -&gt; 没有不必要的训练? - Optimize number of hidden layers and neurons with RandomizedSearchCV (scikit-learn) -> No unnecessary trainings? 为什么我使用 Scikit-learn 的 RandomizedSearchCV 使位置索引器越界错误 - Why am I getting a positional indexer out of bounds error with Scikit-learn's RandomizedSearchCV
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM