I create the train and test datasets using Keras ImageDataGenerator.flow_from_directory(...). Then I want to pass this data to model.fit(). In TensorFlow 2.1 it works perfectly fine. However, running the same code in TensorFlow 2.2 raises: TypeError: data type not understood. How would you suggest overcoming this issue and running it in TF 2.2?
Code sample:
import os
import tensorflow as tf
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')
...
model.fit(train_data, epochs=50)  # Raises the error in TF 2.2, but works fine in TF 2.1.
Another way to trigger this error in TF 2.2 is to iterate over the generator:
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')
for x, y in train_data:
    print(type(x), type(y))
    break  # the directory iterator repeats forever; one batch is enough to trigger the error
The problem turned out to be the Keras package versions. The following configuration caused the error:
keras 2.3.1
keras-preprocessing 1.1.2
After switching to these versions, everything works fine:
keras 2.4.3
keras-preprocessing 1.1.0
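Because the fix comes down to which package versions are installed, it can help to check them in the same environment that runs model.fit(). The snippet below is a minimal sketch, assuming Python 3.8+ (for importlib.metadata). The helper names installed_versions and mismatches are hypothetical, and KNOWN_GOOD simply copies the two versions reported above as working; it is not an official compatibility matrix.

```python
# Hypothetical helper: compare installed package versions with the ones
# reported in this answer as working with TF 2.2.
from importlib.metadata import PackageNotFoundError, version

# Versions reported above as working (assumption: not an official matrix).
KNOWN_GOOD = {
    "keras": "2.4.3",
    "keras-preprocessing": "1.1.0",
}


def installed_versions(names):
    """Return {name: installed version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, pins=KNOWN_GOOD):
    """Return (name, installed, pinned) tuples whose versions differ from the pins."""
    return [
        (name, installed.get(name), pinned)
        for name, pinned in pins.items()
        if installed.get(name) != pinned
    ]


if __name__ == "__main__":
    for name, have, want in mismatches(installed_versions(KNOWN_GOOD)):
        print(f"{name}: installed {have}, reported working with {want}")
```

If the script prints nothing, the installed versions match the ones listed above. With the failing configuration (keras 2.3.1, keras-preprocessing 1.1.2) it reports both packages.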