![](/img/trans.png)
[英]Can I use a neural network on a linear regression using Keras? If yes , How?
[英]How to do bagging using scikit BaggingClassifier with keras Convolutional Neural Network as base estimator via keras-scikit wrapper?
我正在嘗試進行集成學習,即使用 scikit-learn BaggingClassifier 和 2D 卷積神經網絡 (CNN) 作為基本估計器進行裝袋。
在此之前,我曾嘗試使用 scikit 的神經網絡進行裝袋以測試 scikit 的 BaggingClassifier 並且它有效。 我還用 keras-wrapper 測試了 scikit 的 GridSearchCV 來搜索 2D CNN 的超參數,它也有效。
剛才,當我嘗試使用 scikit 的 BaggingClassifier 和 keras-wrapper 進行包裝,然后使用 2D CNN 模型作為基礎估計器創建集成學習時,出現錯誤。
這是代碼片段:
def baggingCNN(self):
from sklearn.ensemble import BaggingClassifier
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils.np_utils import to_categorical
patternTraining = np.reshape(self.patternTraining,
(self.patternTraining.shape[0], 1, 1, self.patternTraining.shape[1]))
patternTesting = np.reshape(self.patternTesting,
(self.patternTesting.shape[0], 1, 1, self.patternTesting.shape[1]))
X = patternTraining
Y_binary = to_categorical(self.targetTraining)
cnnA=KerasClassifier(self.create_cnn_model_A(patternTraining.shape[1],patternTraining.shape[2],patternTraining.shape[3]),nb_epoch=500, batch_size=64, verbose=1)
bagging=BaggingClassifier(base_estimator=cnnA, n_estimators=3, verbose=1, n_jobs=3, max_samples=1)
bagging.fit(X, Y_binary)
這是 create_cnn_model_A 函數的樣子:
def create_cnn_model_A(self, sizeDepth, sizeRow, sizeCol):
from keras.models import Sequential
import keras.layers.core as core
import keras.layers.convolutional as conv
from keras.regularizers import l2, activity_l2, l1, activity_l1, l1l2, activity_l1l2
numFilter = 32
nStride = 1
model = Sequential()
model.add(conv.Convolution2D(nb_filter=numFilter, nb_row=1, nb_col=2, activation='relu',
input_shape=(sizeDepth, sizeRow, sizeCol), border_mode='same'))
model.add(conv.Convolution2D(nb_filter=numFilter, nb_row=1, nb_col=3, activation='relu',
input_shape=(sizeDepth, sizeRow, sizeCol), border_mode='same'))
model.add(conv.Convolution2D(nb_filter=numFilter, nb_row=1, nb_col=4, activation='relu',
input_shape=(sizeDepth, sizeRow, sizeCol), border_mode='same'))
model.add(conv.MaxPooling2D(pool_size=(1, 2), strides=(nStride, nStride), dim_ordering="th"))
model.add(conv.Convolution2D(nb_filter=numFilter, nb_row=1, nb_col=2, activation='relu',
input_shape=(sizeDepth, sizeRow, sizeCol), border_mode='same'))
model.add(conv.Convolution2D(nb_filter=numFilter, nb_row=1, nb_col=2, activation='relu',
input_shape=(sizeDepth, sizeRow, sizeCol), border_mode='same'))
model.add(conv.MaxPooling2D(pool_size=(1, 2), strides=(nStride, nStride), dim_ordering="th"))
model.add(core.Flatten())
model.add(core.Dense(output_dim=50, activation='relu', W_regularizer=l2(0.01)))
model.add(core.Dense(output_dim=18, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', 'precision', 'recall'])
return model
這是重塑前 self.patternTraining 和 self.targetTraining 的形狀:
(1361, 45) (1361,)
這是我得到的錯誤:
Traceback (most recent call last):
File "/home/berylramadhian/PycharmProjects/Relation Extraction/TestModule2.py", line 153, in <module>
clsf.baggingCNN()
File "/home/berylramadhian/PycharmProjects/Relation Extraction/MachineLearning.py", line 511, in baggingCNN
bagging.fit(X, Y_binary)
File "/usr/local/lib/python2.7/dist-packages/sklearn/ensemble/bagging.py", line 248, in fit
return self._fit(X, y, self.max_samples, sample_weight=sample_weight)
File "/usr/local/lib/python2.7/dist-packages/sklearn/ensemble/bagging.py", line 284, in _fit
X, y = check_X_y(X, y, ['csr', 'csc'])
File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.py", line 521, in check_X_y
ensure_min_features, warn_on_dtype, estimator)
File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.py", line 405, in check_array
% (array.ndim, estimator_name))
ValueError: Found array with dim 4. Estimator expected <= 2.
我認為這是某種數組形狀錯誤,但我不知道如何解決這個問題。 或者也許還不可能通過 keras-wrapper 將 scikit 的 BaggingClassifier 與 keras 的 2D CNN 一起使用?
如果需要更多詳細信息,我已准備好提供。 任何幫助表示贊賞,謝謝。
使用 Keras 的 sklearn 當前不支持此功能。 您必須自己實施,否則我幾個月前在社區中提出了同樣的問題。 我得到了一個恰當的回應,目前,他們正在嘗試實施它。 等待下一個版本或檢查問題以了解更多信息。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.