簡體   English   中英

AttributeError: 'History' 對象沒有屬性 'predict_classes'

[英]AttributeError: 'History' object has no attribute 'predict_classes'

我正在嘗試使用 keras 創建一個分類器,但由於某種原因我無法從我的測試集中生成一些類預測。 為此,我使用了以下模型。

def get_model():          #takes ch1, ch2, y_train

nclass = 6
#Channel 1
ch1_input = Input(shape=X_train_ch1[0].shape)       #(3000,1)
signal_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(ch1_input)
signal_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(signal_1)
signal_1 = MaxPool1D(pool_size=2)(signal_1)
signal_1 = SpatialDropout1D(rate=0.1)(signal_1)
signal_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(signal_1)
signal_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(signal_1)
signal_1 = MaxPool1D(pool_size=2)(signal_1)
signal_1 = SpatialDropout1D(rate=0.)(signal_1)
signal_1 = Convolution1D(64, kernel_size=3, activation=activations.relu, padding="valid")(signal_1)
signal_1 = Convolution1D(64, kernel_size=3, activation=activations.relu, padding="valid")(signal_1)
flatten_1 = Flatten()(signal_1)

#Channel 2
ch2_input = Input(shape=X_train_ch2[0].shape)
signal_2 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(ch2_input)
signal_2 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(signal_2)
signal_2 = MaxPool1D(pool_size=2)(signal_2)
signal_2 = SpatialDropout1D(rate=0.1)(signal_2)
signal_2 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(signal_2)
signal_2 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(signal_2)
signal_2 = MaxPool1D(pool_size=2)(signal_2)
signal_2 = SpatialDropout1D(rate=0.2)(signal_2)
signal_2 = Convolution1D(64, kernel_size=3, activation=activations.relu, padding="valid")(signal_2)
signal_2 = Convolution1D(64, kernel_size=3, activation=activations.relu, padding="valid")(signal_2)
flatten_2 = Flatten()(signal_2)

# merge CNN's being trained on each channel 
merged = concatenate([flatten_1, flatten_2])

# Output
dense_1 = Dropout(rate=0.15)(Dense(64, activation=activations.relu, name="dense_1")(merged))
#dense_1 = Dense(, activation=activations.relu)(dense_1)
dense_1 = Dropout(rate=0.25)(Dense(32, activation=activations.relu, name="dense_2")(dense_1))
dense_1 = Dense(nclass, activation=activations.softmax, name="dense_3")(dense_1)

# Compile model 
model = Model(inputs=[ch1_input, ch2_input], outputs=dense_1)
model.compile(loss='categorical_crossentropy',
          optimizer='adam',
          metrics=['accuracy'])
model.summary()
print(model.summary)
return model

#  --------------------------- Create train data and model
sequence = Standardise_and_Augment(sequence)  
X_train_ch1, X_train_ch2, X_test_ch1, X_test_ch2, X_test, y_train, y_test = Process_data()  
y_flat = np.argmax(y_train, axis=1) 
model = get_model()

#  ---------------------------  Run model 
ch_model = model.fit([X_train_ch1,X_train_ch2], y_train, epochs=20, batch_size=32, 
                 validation_split=0.2 ,class_weight='auto', shuffle = True)

#  ---------------------------  Get get class breakdown
from sklearn.metrics import classification_report

Y_test = np.argmax(y_test, axis=1) # Convert one-hot to index
y_pred = ch_model.predict_classes(X_test)
print(classification_report(Y_test, y_pred))

運行它會給出 AttributeError: 'History' object has no attribute 'predict_classes' 我知道我的模型歷史正在被存儲,因為我可以通過運行生成我的模型性能圖:

# Plot model accuracy 
plt.subplot(2, 1, 1)
plt.plot(ch_model.history['accuracy'])
plt.plot(ch_model.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.show()

到目前為止,我所看到的針對此錯誤的所有解決方案都是關於使用我確定我正在使用的順序模型。 如果有人能讓我知道我哪里出錯了或任何生成 y_pred 的替代方法,我將不勝感激。

model.fit不會返回一個模型實例,您可以在其中調用predict ,因此您在錯誤的對象上調用predict ,正確的方法是:

model.fit([X_train_ch1,X_train_ch2], y_train, epochs=20, batch_size=32, 
          validation_split=0.2 ,class_weight='auto', shuffle = True)

y_pred = model.predict_classes(X_test)

顯然我誤解了這個問題。 正如指出的那樣 model.fit 返回一個歷史對象,因此不能用於進行預測。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM