
Model.fit in Keras with multi-label classification

I'm trying to learn how to use my own dataset with the model seen here: resnet, which is just a ResNet model written in Keras. The code contains this line:

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

and then converts the labels, as its comment says, to "Convert class vectors to binary class matrices":

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
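For integer labels like CIFAR-10's (shape (num_samples, 1), values 0 to 9), to_categorical produces one row per sample with a single 1 in the column of that sample's class. A NumPy-only sketch of the same transformation, assuming integer class ids in range(num_classes); the helper name one_hot is hypothetical and this is an illustration, not Keras's own implementation:

```python
import numpy as np


def one_hot(labels, num_classes):
    """Return a (num_samples, num_classes) float32 matrix with a 1 at each label's column.

    `labels` may have shape (n,) or (n, 1), like cifar10's y_train.
    """
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)  # flatten (n, 1) -> (n,)
    eye = np.eye(num_classes, dtype=np.float32)              # row i is the one-hot vector for class i
    return eye[labels]                                        # pick one identity row per sample


y = np.array([[3], [0], [9]])   # three CIFAR-10-style integer labels
encoded = one_hot(y, 10)
print(encoded.shape)            # (3, 10)
print(encoded[0])               # 1.0 in column 3, zeros elsewhere
```

Indexing an identity matrix with the label array is a compact way to build all rows at once without a Python loop.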

and then passes these arrays to the model's fit method like so:

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          shuffle=True,
          callbacks=callbacks)
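When x_train and y_train are NumPy arrays, model.fit walks over all of them each epoch in slices of batch_size samples, reshuffling first when shuffle=True. A simplified pure-NumPy sketch of that iteration; the helper iterate_minibatches is hypothetical and only illustrates the idea, it is not Keras's internal code:

```python
import math

import numpy as np


def iterate_minibatches(x, y, batch_size, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) pairs covering every sample once per epoch."""
    assert len(x) == len(y), "features and labels must have the same number of samples"
    indices = np.arange(len(x))
    if shuffle:
        # One permutation applied to both arrays keeps each image paired with its label.
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(x), batch_size):
        batch_idx = indices[start:start + batch_size]  # the last batch may be smaller
        yield x[batch_idx], y[batch_idx]


x = np.zeros((10, 32, 32, 3), dtype=np.float32)  # 10 dummy 32x32 RGB images
y = np.zeros((10, 3), dtype=np.float32)           # 10 dummy label rows, 3 classes
batches = list(iterate_minibatches(x, y, batch_size=4, seed=0))
print(len(batches), math.ceil(len(x) / 4))        # 3 3
print([len(xb) for xb, _ in batches])             # [4, 4, 2]
```

So the number of steps per epoch is ceil(num_images / batch_size), and shuffling never breaks the image/label pairing because the same index order is used for both arrays.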

I believe I can create x_train by doing something like this (assuming I have a list of image paths):

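import numpy as np
from PIL import Image

# assumes image_paths is a list of image file paths, all with the same size
x_train = []
for path in image_paths:
    im = np.asarray(Image.open(path))  # shape (height, width, channels)
    x_train.append(im)
x_train = np.stack(x_train)  # shape (num_images, height, width, channels)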

Is the above correct?
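A fuller version of this loading step, as a sketch assuming Pillow and NumPy are installed and that every image should be resized to one fixed size. The helper name load_images and the 32x32 default are illustrative choices, not taken from the resnet example. Converting to RGB guards against grayscale or RGBA files, which would otherwise produce arrays with a different number of channels, and np.stack fails unless every array has the same shape.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def load_images(image_paths, size=(32, 32)):
    """Load images into a float32 array of shape (num_images, height, width, 3).

    `size` is (width, height), the order PIL's Image.resize expects.
    """
    arrays = []
    for path in image_paths:
        with Image.open(path) as img:
            img = img.convert("RGB").resize(size)                     # force 3 channels and a common size
            arrays.append(np.asarray(img, dtype=np.float32) / 255.0)  # scale pixels to [0, 1]
    return np.stack(arrays)                                           # shape (N, height, width, 3)


# Usage: write two differently sized red images to a temporary folder, then load them.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i, (w, h) in enumerate([(40, 30), (64, 64)]):
        p = os.path.join(tmp, f"img{i}.png")
        Image.new("RGB", (w, h), color=(255, 0, 0)).save(p)
        paths.append(p)
    x_train = load_images(paths)

print(x_train.shape)  # (2, 32, 32, 3)
```

Scaling to [0, 1] is optional but common; the CIFAR-10 ResNet example divides pixel values by 255 in a similar way before training.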

As for y_train, I don't quite understand what is being passed into model.fit. Is it an array of one-hot encoded arrays? For example, if I had 3 images containing, respectively, a cat and a dog, a dog, and a cat, would y_train be

[
 [1, 1, 0],#cat and dog
 [0, 1, 0],#dog
 [1, 0, 0]#cat
]

or am I mistaken on this as well?

So, model.fit() expects x_train as the features and y_train as the labels for a given classification problem. I'll consider multiclass image classification here, where each image belongs to exactly one class.

  • x_train: For image classification, this argument has the shape (num_images, height, width, num_channels) with Keras's default channels-last format, where num_images is the total number of training images; model.fit itself splits them into batches of batch_size. See here.

  • y_train: The one-hot encoded labels, with shape (num_images, num_classes).
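These bullets describe the multiclass case, where each row has exactly one 1. For the multi-label case in the question (one image can contain both a cat and a dog), the label matrix is usually multi-hot instead: each row may contain several 1s, which is exactly the [1, 1, 0] row proposed in the question. Such models are commonly trained with a sigmoid output layer and binary_crossentropy loss rather than softmax with categorical_crossentropy. A minimal sketch of building y_train this way; the helper name multi_hot and the third class "bird" are hypothetical:

```python
import numpy as np


def multi_hot(label_lists, class_names):
    """Encode each sample's list of class names as a 0/1 row over `class_names`."""
    index = {name: i for i, name in enumerate(class_names)}
    y = np.zeros((len(label_lists), len(class_names)), dtype=np.float32)
    for row, labels in enumerate(label_lists):
        for name in labels:
            y[row, index[name]] = 1.0  # several columns may be set for one image
    return y


class_names = ["cat", "dog", "bird"]  # "bird" is a placeholder third class
y_train = multi_hot([["cat", "dog"], ["dog"], ["cat"]], class_names)
print(y_train.astype(int).tolist())   # [[1, 1, 0], [0, 1, 0], [1, 0, 0]]
```

The shape is still (num_images, num_classes); only the number of 1s per row differs from one-hot encoding.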

Note that num_images is the first dimension of both arguments, so x_train and y_train must contain the same number of samples, in the same order, so that each image lines up with its label.
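A quick sanity check before calling model.fit can catch mismatched arrays early. This is a hypothetical helper (check_fit_inputs is not a Keras function), sketched assuming channels-last image arrays and 2-D label matrices as described above:

```python
import numpy as np


def check_fit_inputs(x, y, num_classes):
    """Raise ValueError if x and y cannot be paired up for model.fit."""
    if x.ndim != 4:
        raise ValueError(f"x should be (num_images, height, width, channels), got {x.shape}")
    if y.ndim != 2 or y.shape[1] != num_classes:
        raise ValueError(f"y should be (num_images, {num_classes}), got {y.shape}")
    if x.shape[0] != y.shape[0]:
        raise ValueError(f"{x.shape[0]} images but {y.shape[0]} label rows")


check_fit_inputs(np.zeros((3, 32, 32, 3)), np.zeros((3, 10)), num_classes=10)  # passes silently
try:
    check_fit_inputs(np.zeros((3, 32, 32, 3)), np.zeros((2, 10)), num_classes=10)
except ValueError as err:
    print(err)  # 3 images but 2 label rows
```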

Hope that helps.
