I am using a scikit-learn MLPClassifier (neural network) to classify some images. The image data is loaded as a multidimensional array of floats. If I don't convert the one-hot encoded labels to floats, I get a label mismatch when training the model, so I convert them. However, when I try to score the predictions, I get "'numpy.float64' object is not iterable". Any suggestions on how to get this to work?
import numpy as np
import sys
import pandas as pd
from skimage import io
from skimage import transform as trans
from sklearn.neural_network import MLPClassifier as NN
from sklearn.model_selection import train_test_split
#Get the data
print ("Reading CSV...")
data = pd.read_csv(filepath_or_buffer="hot_dog_data.csv", nrows=30)
X = data.values[1:,0]
Y = data.values[1:,1:8]
#convert the images to RGB number arrays
print ('Converting Images...')
img_converts = []
for line in X:
    img = io.imread("./images/"+line)
    img = trans.resize(img,(300,400), mode='constant')
    img_converts.append(img)
X = np.array(img_converts)
# Split into train and test vars
trainX, testX, trainY, testY = train_test_split(X,Y, test_size=0.17)
# Reshape the image arrays into 2-D arrays so it will fit the model
xa, xb, xc, xd = trainX.shape
d2_trainX = trainX.reshape((xa, xb*xc*xd))
xe, xf, xg, xh = testX.shape
d2_testX = testX.reshape((xe, xf*xg*xh))
clf = NN(solver='lbfgs',hidden_layer_sizes=(5, 2), random_state=1)
# Recast the Y data so the fit won't get a label mismatch
trainY = np.asarray(trainY, dtype=float)
testY = np.asarray(testY, dtype=float)
print ('The machine is learning...')
clf.fit(d2_trainX, trainY)
print ('Predicting...')
count = 1
for line in clf.predict(d2_testX):
    print (count, line)
    count += 1
print ('Calculating Accuracy...')
count = 1
for x,line in clf.score(d2_testX, testY):
    print (count, line)
sys.exit()
In the line
for x,line in clf.score(d2_testX, testY):
you're trying to iterate over the single float returned by score(), which is what raises the error. From the documentation:
score(X, y, sample_weight=None)
...
Returns: score : float
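So the fix is to use the return value directly rather than looping over it. Here is a minimal sketch of what that could look like. The small random dataset stands in for your flattened images and label columns (that data, the variable names, and max_iter=500 are my assumptions, not taken from your script). If you want a per-sample result, which is what the loop seems to be after, you can compare the predicted rows against the true rows yourself.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

# Tiny synthetic stand-in for the flattened image data
# (assumption: 40 samples with 6 features each).
rng = np.random.RandomState(0)
X = rng.rand(40, 6)
# Two binary indicator columns, mimicking a multi-column label array.
Y = np.column_stack([X[:, 0] > 0.5, X[:, 1] > 0.5]).astype(int)

train_X, test_X = X[:30], X[30:]
train_Y, test_Y = Y[:30], Y[30:]

clf = MLPClassifier(solver='lbfgs', hidden_layer_sizes=(5, 2),
                    random_state=1, max_iter=500)
clf.fit(train_X, train_Y)

# score() returns one float, so print it instead of iterating over it.
accuracy = clf.score(test_X, test_Y)
print('Accuracy:', accuracy)

# Per-sample view: a row counts as correct only if every label matches.
predictions = clf.predict(test_X)
row_correct = (predictions == test_Y).all(axis=1)
for count, (predicted, correct) in enumerate(zip(predictions, row_correct), start=1):
    print(count, predicted, correct)
```

With multi-column labels like these, score() reports subset accuracy, so the mean of row_correct should match the printed accuracy.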