Error trying to classify 3D images using Naive Bayes

I've built a convolutional neural network to classify images, and now I want to build a Naive Bayes classifier for comparison. My images are 3D, and I think that's the cause of the error I'm getting.

The error:

raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (1776, 3)

My code:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
import numpy as np

much_data = np.load('muchdata-50-50-30-normalizado.npy', allow_pickle=True)
X = [data[0] for data in much_data]
y = [data[1] for data in much_data]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
gnb = GaussianNB()
y_pred = gnb.fit(X_train, y_train).predict(X_test)
print("Number of mislabeled points out of a total %d points : %d" % (X_test.shape[0], (y_test != y_pred).sum()))

My X[0] is in the following format:

  [[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]
  ...
  [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]

And my y[0]:

[0 1 0]

If someone can help me understand what I'm doing wrong, it would be really helpful!

Thank you so much!

Judging by your y[0], you have 3 classes in one-hot encoded format. Many scikit-learn estimators, including GaussianNB, expect the target y as a 1D array of class labels, not as one-hot rows. That is exactly what the error reports: the shape (1776, 3) is your training target, 1776 samples by 3 one-hot columns. In addition, the input X to the model must have the shape (no_samples, no_features), so you must flatten the 3D images.
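To confirm that the target shape alone triggers the error, here is a minimal sketch with small random synthetic data standing in for your .npy file (the sizes are arbitrary assumptions). It fits GaussianNB once with one-hot targets and once with the same labels as a 1D array:

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# Hypothetical synthetic data (not the real .npy file):
# 12 samples with 20 already-flat features each, 3 classes.
rng = np.random.default_rng(0)
X = rng.random((12, 20))
labels = np.arange(12) % 3       # 1D class indices: 0, 1, 2, 0, 1, 2, ...
y_onehot = np.eye(3)[labels]     # 2D one-hot targets, shape (12, 3), rows like [0 1 0]

# A 2D one-hot target is rejected with a ValueError about its shape.
try:
    GaussianNB().fit(X, y_onehot)
    rejected = False
except ValueError as err:
    rejected = True
    print("ValueError:", err)

# The same labels as a 1D array are accepted.
model = GaussianNB().fit(X, labels)
print(model.classes_)  # [0 1 2]
```

The exact wording of the message differs between scikit-learn versions, but the 2D target is rejected either way.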

  1. Get rid of the one-hot encoding in the target (y) and obtain a 1D array of shape (no_samples,). You can do this by mapping the 3 classes to integer labels such as 1, 2, 3; for example, np.argmax(y, axis=1) turns each one-hot row into 0, 1 or 2.
  2. Flatten the images. You can do this with X = [data[0].flatten() for data in much_data]
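Putting both steps together, a sketch of the corrected script might look like the following. Since the real file isn't available here, it builds a hypothetical much_data list of [volume, one-hot label] pairs with an assumed volume shape of (6, 10, 10); with your data you would keep the np.load line instead. It also wraps X and y in np.array, because train_test_split returns lists when given lists, and X_test.shape in your print line would then fail.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Hypothetical stand-in for:
#   much_data = np.load('muchdata-50-50-30-normalizado.npy', allow_pickle=True)
# Each entry is [3D volume, one-hot label]; the (6, 10, 10) shape is an assumption.
rng = np.random.default_rng(0)
much_data = []
for i in range(60):
    cls = i % 3
    volume = rng.normal(loc=cls, scale=1.0, size=(6, 10, 10))  # class-dependent mean
    much_data.append([volume, np.eye(3, dtype=int)[cls]])

# Step 1: one-hot rows such as [0 1 0] -> class indices 0, 1, 2; y has shape (n_samples,).
y = np.array([np.argmax(data[1]) for data in much_data])
# Step 2: flatten each 3D volume into one feature vector; X has shape (n_samples, n_features).
X = np.array([data[0].flatten() for data in much_data])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
gnb = GaussianNB()
y_pred = gnb.fit(X_train, y_train).predict(X_test)
print("Number of mislabeled points out of a total %d points : %d"
      % (X_test.shape[0], (y_test != y_pred).sum()))
```

np.argmax gives labels 0, 1, 2; labels 1, 2, 3 would work just as well, since GaussianNB only needs one label per sample.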

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM