
How do I obtain the number of classes when using tf.keras.preprocessing.image_dataset_from_directory?

import tensorflow as tf

img_height, img_width = 180, 100
batch_size = 32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir1,
    validation_split=0.01,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Output:
Found 1376 files belonging to 4 classes.
Using 1363 files for training.

How can I store the total number of classes in a variable?

If you have something like

train_gen = tf.keras.preprocessing.image_dataset_from_directory(...)

then you can use the code below to get the information you want:

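classes = train_gen.class_names                # list of class names, one per subdirectory
class_indices = list(range(len(classes)))      # integer label i corresponds to classes[i]
num_of_classes = len(classes)

train_gen.class_names is a list of the class names in label order (alphanumerical order of the subdirectory names, unless you pass class_names yourself). The class_indices dictionary of the form {class: index} belongs to the iterator returned by ImageDataGenerator.flow_from_directory, not to the tf.data.Dataset returned by image_dataset_from_directory, so accessing it on train_gen raises an AttributeError.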
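If you need the class count before building the dataset (for example, to size the final Dense layer), a minimal sketch is to list the class subdirectories yourself. This assumes that, with the default labels="inferred" and no class_names argument, image_dataset_from_directory treats each immediate subdirectory of the data directory as one class and sorts the names alphanumerically; edge cases such as hidden folders may differ. The helper name infer_class_names and the folder names are hypothetical.

```python
import os
import tempfile


def infer_class_names(data_dir):
    """Return the sorted names of the immediate subdirectories of data_dir.

    Assumption: this approximates how image_dataset_from_directory infers
    class names (one class per subdirectory, alphanumerical order).
    Plain files directly inside data_dir are ignored.
    """
    return sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())


if __name__ == "__main__":
    # Build a throwaway tree with four hypothetical class folders.
    with tempfile.TemporaryDirectory() as root:
        for name in ["rose", "daisy", "tulip", "sunflower"]:
            os.makedirs(os.path.join(root, name))
        class_names = infer_class_names(root)
        num_of_classes = len(class_names)
        print(class_names, num_of_classes)  # ['daisy', 'rose', 'sunflower', 'tulip'] 4
```

With the dataset already built, len(train_ds.class_names) gives the same number without touching the file system.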

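To get a {class: index} mapping for the dataset from the question, build it from class_names:

label_map = {name: index for index, name in enumerate(train_ds.class_names)}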
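A label map is most useful in both directions: name to index for inspecting labels, and index to name for turning a model's output back into a class. The sketch below assumes that label i belongs to class_names[i], which is how image_dataset_from_directory assigns integer labels when they are inferred. The functions build_label_maps and decode_prediction and the class names are hypothetical, and the argmax is done in plain Python so the example runs without TensorFlow.

```python
def build_label_maps(class_names):
    """Return (name -> index, index -> name) dictionaries.

    Assumes class_names is ordered so that integer label i is class_names[i].
    """
    name_to_index = {name: index for index, name in enumerate(class_names)}
    index_to_name = dict(enumerate(class_names))
    return name_to_index, index_to_name


def decode_prediction(scores, index_to_name):
    """Return the class name whose score is highest (plain-Python argmax)."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return index_to_name[best_index]


if __name__ == "__main__":
    # Hypothetical stand-in for train_ds.class_names.
    class_names = ["cat", "dog", "horse", "sheep"]
    label_map, inverse_map = build_label_maps(class_names)
    print(label_map)  # {'cat': 0, 'dog': 1, 'horse': 2, 'sheep': 3}
    print(decode_prediction([0.1, 0.2, 0.6, 0.1], inverse_map))  # horse
```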

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM