
Labels size different from target_names: TensorFlow Multi-Input Regression converting to Classification

I am trying to convert a multi-input, mixed-data (text, image) Keras model from a regression output (house price) to a classification output (number of bedrooms). In particular, I am adapting this tutorial to be a classifier:

https://www.pyimagesearch.com/2019/02/04/keras-multiple-inputs-and-mixed-data/

I have a couple of technical questions about the number of categories, and I also get an error that I don't fully understand.

I have altered the last layer of the network to be a softmax:

x = Dense(11, activation="softmax")(x)

However, I only have 10 categories (the dataset covers houses with 1-10 bedrooms). But with Dense(10, ...) I get the following error:

InvalidArgumentError: Received a label value of 10 which is outside the valid range of [0, 10). Label values: 3 2 5 2 10 3 2 5

I understand the error and how to avoid it, but why isn't the range [0, 10) sufficient, given that I don't have any houses with 0 bedrooms?

When I try to generate a classification report, I get two warnings:

UserWarning: labels size, 6, does not match size of target_names, 10

UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.

I think these might be because my classification report only contains houses with 1-6 bedrooms, but I am not sure. Any insight you can give would be appreciated.
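One way to test that hypothesis is to count the distinct classes that actually occur in the test labels and predictions. When no labels argument is passed, scikit-learn's classification_report builds its rows from the distinct values found in y_true and y_pred combined, so fewer than 10 distinct values cannot line up with 10 target_names. The sketch below uses only NumPy and made-up arrays (y_true, y_pred and the 1-6 coverage are hypothetical, chosen to mirror the warning).

```python
import numpy as np

# One display name per bedroom count 1..10, as passed to target_names.
target_names = [str(n) for n in range(1, 11)]

# Hypothetical test labels and predictions (bedroom counts) that only cover 1-6.
y_true = np.array([1, 2, 3, 3, 4, 5, 6, 2])
y_pred = np.array([2, 2, 3, 4, 4, 5, 5, 2])

# Without an explicit labels= argument, the report's rows come from the
# sorted distinct values found in y_true and y_pred combined.
present = np.union1d(y_true, y_pred)
print(len(present), len(target_names))  # 6 10

# Classes in the report that are never predicted have no predicted samples,
# so their precision (and F-score) is ill-defined: UndefinedMetricWarning.
never_predicted = np.setdiff1d(present, y_pred).tolist()
print(never_predicted)  # [1, 6]
```

If that is the cause, passing the full set of classes explicitly, e.g. labels=list(range(1, 11)) (or range(10) once the labels are shifted to 0-9), to classification_report keeps its rows aligned with all ten target_names; classes that never occur then show zeros. In scikit-learn 0.22 and later, zero_division=0 sets those undefined values to 0 without emitting the warning.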

My code and the dataset can be cloned from here: https://github.com/davidrtfraser/blog-keras-multi-input

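Generally in machine learning, labels for N classes are encoded as integers in the range 0 to N - 1, because these map directly onto the indices of the model's output units, so you can use argmax to recover them from the model outputs. A softmax layer with N units therefore only has outputs for class indices 0 to N - 1: with Dense(10), a label of 10 has no output unit to match, regardless of whether label 0 ever occurs. Dense(11) only avoids the error by adding an eleventh unit for index 10, while the unit for index 0 is never used.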
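To make the index mapping concrete, here is a small NumPy sketch with made-up softmax outputs for N = 4 classes (the probs values are hypothetical, not model output): argmax returns a column index, and column indices can only run from 0 to N - 1.

```python
import numpy as np

# Hypothetical softmax outputs for 3 samples over N = 4 classes (each row sums to 1).
probs = np.array([
    [0.10, 0.70, 0.15, 0.05],
    [0.05, 0.05, 0.10, 0.80],
    [0.60, 0.20, 0.10, 0.10],
])

# argmax picks the column with the highest probability; a column index
# is always in the range 0 .. N - 1, so labels must use that same range.
pred_index = np.argmax(probs, axis=1)
print(pred_index)  # [1 3 0]
```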

So you need to encode your labels the same way. The easiest way is to shift your [1, 10] labels to [0, 9] by subtracting one from each label; to get the number of bedrooms back from the model output, add one to the predicted label.
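A minimal sketch of that shift, assuming the labels are NumPy integer arrays and the model is compiled with sparse_categorical_crossentropy on integer labels (which is what the "valid range of [0, 10)" error suggests). encode_bedrooms and decode_bedrooms are hypothetical helper names, not part of Keras or the tutorial's code.

```python
import numpy as np

NUM_CLASSES = 10  # houses with 1-10 bedrooms -> Dense(10, activation="softmax")

def encode_bedrooms(bedrooms):
    """Shift bedroom counts 1..10 to class indices 0..9."""
    bedrooms = np.asarray(bedrooms)
    if bedrooms.min() < 1 or bedrooms.max() > NUM_CLASSES:
        raise ValueError(f"bedroom counts must be in [1, {NUM_CLASSES}]")
    return bedrooms - 1

def decode_bedrooms(probs):
    """Turn softmax outputs of shape (samples, NUM_CLASSES) back into bedroom counts."""
    return np.argmax(probs, axis=1) + 1

# Label values taken from the error message in the question.
labels = encode_bedrooms([3, 2, 5, 2, 10, 3, 2, 5])
print(labels)  # [2 1 4 1 9 2 1 4]

# Hypothetical prediction for one house: highest probability at index 9.
probs = np.zeros((1, NUM_CLASSES))
probs[0, 9] = 1.0
print(decode_bedrooms(probs))  # [10]
```

With the labels shifted, the final layer can go back to Dense(10, activation="softmax"), so every output unit corresponds to a real bedroom count and none is wasted on 0.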
