[英]Keras model.predict error for categorical labels
我試圖查看預測結果並使用model.predict函數打印它們,但出現錯誤:
ValueError: Error when checking model : the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[array([ 0,....
我有多個輸入,都是嵌入式的。 當我將一個輸入嵌入時,此代碼以前有效。
for i in range(100):
prediction_result = model.predict(np.array([test_text[i], test_posts[i]]))
predicted_label = labels_name[np.argmax(prediction_result)]
print(text_data.iloc[i][:100], "")
print('Actual label:' + tags_test.iloc[i])
print("Predicted label: " + predicted_label + "\n")
test_text和test_posts是pad_sequences的結果。 它們在數組中,test_text的形狀為100,test_posts的形狀為1。labels_name是標簽的名稱。 我在第二行中有錯誤;
prediction_result = model.predict(np.array([test_text[i], test_posts[i]]))
錯誤:
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
1815 x = _standardize_input_data(x, self._feed_input_names,
1816 self._feed_input_shapes,
-> 1817 check_batch_axis=False)
1818 if self.stateful:
1819 if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
84 'Expected to see ' + str(len(names)) + ' array(s), '
85 'but instead got the following list of ' +
---> 86 str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
87 elif len(names) > 1:
88 raise ValueError(
ValueError: Error when checking model : the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[array([ 0, 0, ...
它看起來像一個簡單的解決方案,但我找不到。 謝謝您的幫助。
該模型需要兩個數組,並且您要傳遞一個單獨的numpy數組。
prediction_result = model.predict([test_text.values[i].reshape(-1,100), test_posts.values[i].reshape(-1,1)])
刪除調用numpy.array方法,您錯誤就會消失。
更新:
無需使用for loop
。
prediction_result = model.predict([test_text.values.reshape(-1,100), test_posts.values.reshape(-1,1)])
這可以做您想要的。 現在的(number rows in test_text,number of outputs)
形狀為(number rows in test_text,number of outputs)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.