
SageMaker Image Classification: How to get an ordered list of classes corresponding to the output of the model

I'm training a model for multi-class image classification on AWS SageMaker using a custom dataset. The dataset has around 50 classes. I'm following this notebook: Image classification transfer learning demo

As I understand it, the final layer of the model outputs a probability for each class in the dataset. SageMaker expects the dataset to be provided in MXNet RecordIO (.rec) format. Since I'm not manually one-hot encoding the labels, I don't know which output index holds the probability for which class. How can I get an ordered list of classes whose indexes correspond to the outputs of the model's final layer?

Even the notebook provided by AWS (linked above) hard-codes that ordered list (the list object_categories).

My dataset before converting to .rec format looks like this:

./train/object1/
   -image1.jpg
   -image2.jpg
   -image3.jpg
   -...image500.jpg
./train/object2/
   -image1.jpg
   -image2.jpg
   -image3.jpg
   -...image500.jpg
.
.
.
./train/object50/
   -image1.jpg
   -image2.jpg
   -image3.jpg
   -...image500.jpg

Any help will be highly appreciated.

The labels are embedded in the RecordIO .rec file, so for custom multi-label applications you will have to re-label. Frankly, the documentation isn't great, but here's a starting point (go to the bottom): https://mxnet.incubator.apache.org/faq/recordio.html

As for the layer that has the labels, the final layer is what generates the class probabilities. The model architecture is abstracted away in SageMaker, and setting the num_classes hyperparameter sizes the final layer so that it assigns one probability per class; its value should match the number of classes found in the .rec file.
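Once an ordered class list is known, pairing it with the model output is straightforward. The following is only a minimal sketch: top_predictions is a hypothetical helper, and the class names and probabilities are made-up values for illustration, not output from a real endpoint.

```python
def top_predictions(probabilities, class_names, k=3):
    """Pair each output probability with its class name and return the k best.

    Assumes probabilities[i] is the score for class_names[i].
    """
    if len(probabilities) != len(class_names):
        raise ValueError("expected one probability per class")
    ranked = sorted(
        zip(class_names, probabilities),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]


# Hypothetical values for illustration only.
class_names = ["object1", "object2", "object3"]
probabilities = [0.1, 0.7, 0.2]
print(top_predictions(probabilities, class_names, k=2))
# [('object2', 0.7), ('object3', 0.2)]
```

The length check is there because a mismatch between the class list and the output size usually means the list was built from a different dataset than the one used for training.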

You need to use im2rec to create the RecordIO file from the .lst file. The .lst file is created from the input dataset, and the one-hot encoding of the labels is done in the .lst file. Please refer to this notebook for an example of how to create multi-label input.
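Since the .lst file is what ties label indexes to image paths, one way to recover the ordered class list is to read it back. Here is a rough sketch, assuming single-label rows of the form index, label, relative path separated by tabs, and assuming the first folder in each path is the class name (as in the ./train/object1/ layout from the question). classes_from_lst is a hypothetical helper, not part of im2rec or SageMaker.

```python
import csv
import io
from pathlib import PurePosixPath


def classes_from_lst(lst_file):
    """Return class names ordered by the label index stored in a .lst file.

    Assumes rows look like: index<TAB>label<TAB>class_folder/image.jpg
    """
    index_to_name = {}
    for row in csv.reader(lst_file, delimiter="\t"):
        if len(row) < 3:
            continue  # skip blank or malformed lines
        label = int(float(row[1]))  # labels may be written as floats, e.g. 0.000000
        rel_path = row[-1].strip().replace("\\", "/")  # tolerate Windows separators
        class_name = PurePosixPath(rel_path).parts[0]
        previous = index_to_name.setdefault(label, class_name)
        if previous != class_name:
            raise ValueError(
                f"label {label} maps to both {previous!r} and {class_name!r}"
            )
    if sorted(index_to_name) != list(range(len(index_to_name))):
        raise ValueError("label indexes are not contiguous from 0")
    return [index_to_name[i] for i in range(len(index_to_name))]


# Tiny in-memory example standing in for a real .lst file.
sample = io.StringIO(
    "0\t0.000000\tobject1/image1.jpg\n"
    "1\t1.000000\tobject2/image1.jpg\n"
    "2\t0.000000\tobject1/image2.jpg\n"
)
print(classes_from_lst(sample))
# ['object1', 'object2']
```

For a real file, pass an open handle, e.g. open("train.lst", newline=""), where train.lst is a placeholder name. The position of each name in the returned list is the label index that was written into the .rec file, so it should line up with the model's output vector.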
