简体   繁体   中英

Using Albumentations augmentation with the TensorFlow Dataset API gives this error: Incompatible shapes: expected [?,224,224,3] but got [8,1,224,224,3]

I am getting this error while trying to augment images with the Albumentations library. Following this example (https://albumentations.ai/docs/examples/tensorflow-example/), I use tf.numpy_function to wrap the Python augmentation function so it can run inside TensorFlow.

I have loaded my dataset of images and target labels using the TensorFlow Dataset API. The code:

# Pull the image file paths and labels out of the dataframe as arrays.
img_paths = df['image_path'].values
target = df['target_label'].values

# Build a dataset of (image_path, target_label) pairs, one element per row.
path_lis = tf.data.Dataset.from_tensor_slices(img_paths)
target_lis = tf.data.Dataset.from_tensor_slices(target)
list_ds = tf.data.Dataset.zip((path_lis, target_lis))

# Use the first 30% of rows as the validation set and the rest for training.
# NOTE(review): nothing shuffles the data before skip/take, so the split
# follows dataframe row order. Confirm that df is not sorted by label.
image_count = len(df)
val_size = int(image_count * 0.3)
train = list_ds.skip(val_size)
val = list_ds.take(val_size)


def process_path(file_path, target):
    """Read the file at `file_path` and decode it as a 3-channel JPEG.

    Returns a `(decoded_image, target)` tuple; `target` is passed through as-is.
    """
    raw_bytes = tf.io.read_file(file_path)  # raw file contents as a string tensor
    return tf.image.decode_jpeg(raw_bytes, channels=3), target

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
# Each element becomes (decoded image, label). AUTOTUNE is defined outside this
# snippet (presumably tf.data.AUTOTUNE; confirm).
train_data = train.map(process_path, num_parallel_calls=AUTOTUNE)
val_data = val.map(process_path, num_parallel_calls=AUTOTUNE)

# Augmentation using albumentations library.
# aug_fn applies this pipeline to each image. The final Resize makes every
# output 224x224, which is the static image shape that set_shapes declares.
# NOTE(review): RandomBrightness/RandomContrast are deprecated in newer
# albumentations releases in favour of RandomBrightnessContrast. Check this
# against the installed version.
transforms = A.Compose([
            A.Rotate(limit=40),
            A.RandomBrightness(limit=0.1),
            A.RandomContrast(limit=0.9, p=1),
            A.HorizontalFlip(),
            A.Resize(224, 224)
            ])

def aug_fn(image):
    """Run the module-level albumentations pipeline `transforms` on `image`.

    This is meant to be called through `tf.numpy_function`, so `image` is a
    NumPy array. The augmented image is divided by 255.0 and cast to
    tf.float32 before it is returned.
    """
    augmented = transforms(image=image)["image"]
    return tf.cast(augmented / 255.0, tf.float32)

def process_aug(img, label):
    """Augment one decoded image inside a tf.data `map`.

    Wraps the NumPy-side `aug_fn` with `tf.numpy_function` and passes the
    label through unchanged.

    Returns:
        `(aug_img, label)`. `aug_img` is a single float32 tensor whose static
        shape is unknown, because numpy_function cannot infer it. The later
        `set_shapes` step restores the shape before batching.
    """
    # Tout must be a single dtype, not a list. With Tout=[tf.float32],
    # numpy_function returns a one-element *list* of tensors. tf.data packs
    # that list into a tensor with an extra leading axis of size 1, so after
    # .batch(8) the images come out as [8, 1, 224, 224, 3] instead of the
    # declared [?, 224, 224, 3].
    aug_img = tf.numpy_function(func=aug_fn, inp=[img], Tout=tf.float32)
    return aug_img, label

# create dataset
# Run the albumentations augmentation on every (image, label) element.
train_ds = train_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds = val_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

def set_shapes(img, label):
    """Declare static shapes that were lost going through tf.numpy_function.

    The image is declared as [224, 224, 3] and the label as a scalar ([]).
    Both objects are returned unchanged apart from the shape declaration.
    """
    for tensor, static_shape in ((img, [224, 224, 3]), (label, [])):
        tensor.set_shape(static_shape)
    return img, label

# Restore the static shapes (unknown after tf.numpy_function), then group
# the elements into batches of 8.
train_ds = train_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
val_ds = val_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)


def view_image(ds):
    """Plot the images and labels of the first batch of `ds` in a 4x5 grid.

    Args:
        ds: a batched dataset yielding `(images, labels)` pairs, such as the
            `.batch(8)` datasets built above. This assumes the images are in
            a format `imshow` accepts (here float32 scaled by 1/255); confirm
            this for other pipelines.

    At most 20 images (the grid size) are drawn. If the batch has fewer
    images, all of them are drawn instead of raising IndexError.
    """
    images, labels = next(iter(ds))  # extract 1 batch from the dataset
    images = images.numpy()
    labels = labels.numpy()

    # The batch size (8 above) can be smaller than the grid. Looping over a
    # hard-coded range(20) would index past the end of the batch.
    n_rows, n_cols = 4, 5
    n_shown = min(len(images), n_rows * n_cols)

    fig = plt.figure(figsize=(22, 22))
    for i in range(n_shown):
        ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
        ax.imshow(images[i])
        ax.set_title(f"Label: {labels[i]}")

# Preview the images from the first training batch.
view_image(train_ds)

The full error message:

Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2102, in execution_mode
    yield
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 758, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py", line 2610, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 6843, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [4,1,224,224,3]. [Op:IteratorGetNext]

During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-23a37450bee7>", line 13, in <module>
    view_image(train_ds)
  File "<ipython-input-20-23a37450bee7>", line 3, in view_image
    image, label = next(iter(ds)) # extract 1 batch from the dataset
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 736, in __next__
    return self.next()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 772, in next
    return self._next_internal()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 764, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "C:\Users\Arun\Anaconda3\lib\contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2105, in execution_mode
    executor_new.wait()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\executor.py", line 67, in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [8,1,224,224,3].

Can someone at least tell me why this error is happening? Thanks in advance!

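The shape passed to `set_shape` must match the image size your pipeline actually produces. For example, with 120×120 RGB images, `img_shape` should be (120, 120, 3), not [224, 224, 3]. Separately, the extra axis of size 1 in [8, 1, 224, 224, 3] comes from `Tout=[tf.float32]` in `process_aug`. Passing a list makes `tf.numpy_function` return a one-element list, which tf.data stacks into an extra dimension. Use `Tout=tf.float32` instead.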

For example:

def set_shapes(img, label, img_shape=(120,120,3)):
    """Declare static shapes: `img_shape` for the image, scalar for the label.

    `img_shape` should match the size the augmentation pipeline produces.
    Both objects are returned unchanged apart from the shape declaration.
    """
    for tensor, static_shape in zip((img, label), (img_shape, [])):
        tensor.set_shape(static_shape)
    return img, label

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM