"Unsupported number of image dimensions" while using image_utils from Transformers

I'm trying to follow this HuggingFace tutorial https://huggingface.co/blog/fine-tune-vit

Using their "beans" dataset everything works, but if I use my own dataset with my own images, I'm hitting "Unsupported number of image dimensions". I'm wondering if anyone here would have pointers for how to debug this.

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_2042949/883871373.py in <module>
----> 1 train_results = trainer.train()
      2 trainer.save_model()
      3 trainer.log_metrics("train", train_results.metrics)
      4 trainer.save_metrics("train", train_results.metrics)
      5 trainer.save_state()

~/miniconda3/lib/python3.9/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1532             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1533         )
-> 1534         return inner_training_loop(
   1535             args=args,
   1536             resume_from_checkpoint=resume_from_checkpoint,

~/miniconda3/lib/python3.9/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1755             step = -1
-> 1756             for step, inputs in enumerate(epoch_iterator):
   1758                 # Skip past any already trained steps if resuming training

~/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
--> 119         raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    121     if image.shape[first_dim] in (1, 3):

ValueError: Unsupported number of image dimensions: 2


I compared the shape of my prepared data with theirs, and it's the same:

$ prepared_ds['train'][0:2]['pixel_values'].shape
torch.Size([2, 3, 224, 224])

I followed the stack trace and found that the error was raised in the infer_channel_dimension_format function, so I wrote this quick-and-dirty loop to find the problematic image:

i = -1
try:
    for i, _ in enumerate(prepared_ds["train"]):  # loading each example runs the transform
        pass
except ValueError as ve:
    print(f"Example {i + 1} failed: {ve}")  # i is the last index that loaded OK

When I inspected that image, I saw that it's not RGB like the others: its mode is L (8-bit grayscale).

$ ds["train"][8]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=L size=390x540>,
 'image_file_path': '/data/alamy/img/00000/000001069.jpg',
 'labels': 0}
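To find every non-RGB image up front instead of stopping at the first failure, a sketch like the one below could scan the raw dataset. It assumes each example holds a PIL image under the "image" key, as ds["train"] does above; find_non_rgb_images and mode_histogram are hypothetical helper names, not part of datasets or transformers.

```python
from collections import Counter


def find_non_rgb_images(examples, image_key="image"):
    """Return (index, mode) pairs for every example whose image is not RGB.

    `examples` can be any iterable of dicts holding PIL images (for instance
    ds["train"]); only each image's `.mode` attribute is inspected.
    """
    return [
        (i, ex[image_key].mode)
        for i, ex in enumerate(examples)
        if ex[image_key].mode != "RGB"
    ]


def mode_histogram(examples, image_key="image"):
    """Count how many images there are of each PIL mode (e.g. "RGB", "L")."""
    return Counter(ex[image_key].mode for ex in examples)
```

Called as find_non_rgb_images(ds["train"]), this should list index 8 with mode "L" for the dataset above. Iterating the raw dataset decodes every image, so it may be slow on large datasets, but it only needs to run once.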

So the solution for me was to add a convert('RGB') to my transform:

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values;
    # convert("RGB") gives grayscale (mode L) images three channels like the rest
    inputs = feature_extractor([x.convert("RGB") for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs
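To see why the conversion helps, the sketch below (assuming Pillow and NumPy are installed) builds a mode-L image with the same size as the one above and compares array shapes before and after convert("RGB"). A grayscale image becomes a 2-D (height, width) array, which is the ndim == 2 case rejected in the traceback; after conversion it is (height, width, 3).

```python
import numpy as np
from PIL import Image

# 8-bit grayscale image, same size as ds["train"][8] (PIL sizes are width x height)
gray = Image.new("L", (390, 540), color=128)
rgb = gray.convert("RGB")  # copies the single luminance channel into R, G and B

gray_array = np.asarray(gray)
rgb_array = np.asarray(rgb)

print(gray.mode, gray_array.shape, gray_array.ndim)  # L (540, 390) 2
print(rgb.mode, rgb_array.shape, rgb_array.ndim)     # RGB (540, 390, 3) 3
```

Calling convert("RGB") unconditionally is harmless for images that are already RGB (Pillow just returns a copy), and it also covers other modes such as RGBA or P (palette), so there is no need to check the mode first.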

I was facing the same error today; after adding this collate function, the error went away:

import torch

def collate_fn(batch):
    # Stack the per-example tensors into one batch for the model
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
