
ValueError: too many values to unpack (expected 2) when using tf.keras.preprocessing.image_dataset_from_directory

I want to create a dataset variable as well as a labels variable using the function tf.keras.preprocessing.image_dataset_from_directory ( https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory ). The documentation states:

Returns: A tf.data.Dataset object. If label_mode is None, it yields float32 tensors of shape (batch_size, image_size[0], image_size[1], num_channels), encoding images (see below for rules regarding num_channels). Otherwise, it yields a tuple (images, labels), where images has shape (batch_size, image_size[0], image_size[1], num_channels), and labels follows the format described below.

My code is the following:

train_ds, labels = tf.keras.preprocessing.image_dataset_from_directory(
  directory = data_dir,
  labels='inferred',
  label_mode = "int",
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

I expected to get a tuple as the return value, but instead I get this error message:

Found 2160 files belonging to 2160 classes.
Using 1728 files for training.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-168-ed9d42ed2ab9> in <module>
      7   seed=123,
      8   image_size=(img_height, img_width),
----> 9   batch_size=batch_size)

ValueError: too many values to unpack (expected 2)

When I save the output in one variable (just train_ds) and I inspect the variable, I get the following output:

<BatchDataset shapes: ((None, 120, 30, 3), (None,)), types: (tf.float32, tf.int32)>

How can I access the two elements of the tuple (the images and the labels) separately?

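You can iterate over the dataset: each batch it yields is an (images, labels) tuple. For example, the following code plots a few images together with their labels: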

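import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
# Take a single batch; each element of train_ds is an (images, labels) tuple.
for images, labels in train_ds.take(1):
    # Plot the first nine images of the batch in a 3x3 grid (assumes batch_size >= 9).
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # Pixel values are float32 in [0, 255]; cast to uint8 for imshow.
        plt.imshow(images[i].numpy().astype("uint8"))
        # Use the integer label as the title.
        plt.title(int(labels[i]))
        plt.axis("off")
plt.show()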

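This code takes one batch from the dataset, plots its first nine images, and titles each image with its label.
Note that you should not try to get the labels in your first line of code. image_dataset_from_directory returns a single tf.data.Dataset, not a tuple; the (images, labels) tuples are the elements that the dataset yields. Writing train_ds, labels = ... makes Python iterate over the dataset's batches to unpack them into two variables, and because there are more than two batches you get "too many values to unpack (expected 2)".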
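If you need the images and labels as separate objects rather than just plotting them, here is a minimal sketch. It assumes TensorFlow 2.x in eager mode and replaces the directory-based dataset with a small hypothetical stand-in built from random arrays with the same image shape (120, 30, 3); with your real data, use the train_ds returned by image_dataset_from_directory instead. The names images_ds, labels_ds and all_labels are illustrative only. It reproduces the unpacking error, unpacks one batch correctly, and uses Dataset.map to split the tuple into an images-only and a labels-only dataset.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the directory-based dataset: 10 random
# "images" of shape (120, 30, 3) with int32 labels, batched by 4 (3 batches).
images = np.random.rand(10, 120, 30, 3).astype("float32")
labels = np.arange(10, dtype="int32")
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Unpacking the dataset object iterates over its batches, so assigning
# to two variables fails when there are more than two batches.
unpack_error = None
try:
    ds, lbl = train_ds
except ValueError as err:
    unpack_error = str(err)
print(unpack_error)

# Each element is an (images, labels) tuple, so unpack per batch instead.
for batch_images, batch_labels in train_ds.take(1):
    print(batch_images.shape, batch_labels.shape)

# map() passes the tuple components as separate arguments, which lets us
# build an images-only and a labels-only dataset.
images_ds = train_ds.map(lambda x, y: x)
labels_ds = train_ds.map(lambda x, y: y)

# Collect every label into one tensor (this iterates the whole dataset).
all_labels = tf.concat(list(labels_ds), axis=0)
print(all_labels.numpy())
```

One caveat: image_dataset_from_directory shuffles by default (shuffle=True), and the order can change each time the dataset is iterated, so images_ds and labels_ds iterated separately may not line up. When the pairing matters, read images and labels in the same loop, as in the for statement of the sketch.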

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM