
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float64, int32, uint8, int16, int8

I have a dataset with 5 labels, loaded with the following code:

import os
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
img_height, img_width = 180, 180  # example size; the question does not state it

def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
  # The second to last is the class-directory
  one_hot = parts[-2] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # resize the image to the desired size (tf.image.resize returns float32)
  return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

# train_ds is a tf.data.Dataset of file paths, built elsewhere
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
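A minimal sketch of where the bool comes from, using the two-class list from the question. The path `data/cat/001.jpg` is hypothetical, and this assumes TensorFlow 2.x, where `==` between a tensor and a list compares elementwise. The scalar directory name is broadcast against every class name, so `one_hot` is a `tf.bool` tensor, and that is the value passed to `tf.argmax` as `input`. Whether `tf.argmax` then rejects it depends on the TensorFlow version; the version in the question does.

```python
import os
import tensorflow as tf

# Hypothetical file path; only the second-to-last component (the class directory) matters.
file_path = tf.constant(os.path.join("data", "cat", "001.jpg"))
class_names = ["dog", "cat"]

parts = tf.strings.split(file_path, os.path.sep)
# Elementwise string comparison against every class name -> one bool per class.
one_hot = parts[-2] == class_names

print(one_hot.dtype)    # the dtype is bool, not a numeric type
print(one_hot.numpy())  # False for "dog", True for "cat"
```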

If I reuse this code with another dataset that has 2 labels, class_names = ['dog', 'cat'], I get this error:

TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64

How should I update get_label(file_path)?

I was having the same problem. Following the idea from the other answer (casting the comparison result to an integer type), I changed the line to:

one_hot = tf.dtypes.cast(parts[-2] == class_names, dtype=tf.int16)
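A self-contained sketch of `get_label` with that cast applied, run through `tf.data.Dataset.map` the same way as in the question. The file paths are hypothetical strings (no files are read). Passing `class_names` in as a parameter is an addition for illustration, so the same function serves both the 5-class and the 2-class dataset. `tf.int16` is kept from the answer; any integer type from the allowed list should work.

```python
import os
import tensorflow as tf

def get_label(file_path, class_names):
    # Split the path; the second-to-last component is the class directory.
    parts = tf.strings.split(file_path, os.path.sep)
    # `==` yields a bool tensor; cast it so tf.argmax receives an allowed dtype.
    one_hot = tf.dtypes.cast(parts[-2] == class_names, dtype=tf.int16)
    # Index of the matching class (tf.argmax returns int64 by default).
    return tf.argmax(one_hot)

class_names = ["dog", "cat"]
# Hypothetical paths; only the directory names matter for the label.
paths = [os.path.join("data", "dog", "1.jpg"), os.path.join("data", "cat", "2.jpg")]

ds = tf.data.Dataset.from_tensor_slices(paths)
labels_ds = ds.map(lambda p: get_label(p, class_names))
labels = [int(x) for x in labels_ds.as_numpy_iterator()]
print(labels)
```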

As the error message says, tf.argmax requires its input to have one of these data types:

float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64

so all you need to do is convert the output of

one_hot = parts[-2] == class_names

to an integer type. The == comparison evaluates to True/False, i.e. a bool tensor, and bool is not in that list.
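One caveat with the cast approach: if a file sits in a directory whose name is not in `class_names`, the cast vector is all zeros and `tf.argmax` still returns some index (the TensorFlow docs note that the index returned for ties is not guaranteed), so the file gets a wrong label without any error. A different technique that avoids `argmax` altogether is a `tf.lookup.StaticHashTable` mapping each class name to its index, with a default of -1 for unknown directories. This is a swapped-in alternative, not what the answers here use, and the paths are hypothetical.

```python
import os
import tensorflow as tf

class_names = ["dog", "cat"]

# Map each class-directory name to its integer index; unknown names map to -1.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(class_names),
        values=tf.range(len(class_names), dtype=tf.int64),
    ),
    default_value=-1,
)

def get_label(file_path):
    # The second-to-last path component is the class directory.
    parts = tf.strings.split(file_path, os.path.sep)
    return table.lookup(parts[-2])

# Hypothetical paths; "bird" is deliberately not a known class.
paths = [os.path.join("data", d, "x.jpg") for d in ("dog", "cat", "bird")]
ds = tf.data.Dataset.from_tensor_slices(paths).map(get_label)
labels = [int(x) for x in ds.as_numpy_iterator()]
print(labels)
```

Filtering out the -1 labels (for example with `Dataset.filter`) is then an explicit choice rather than a silent mislabel.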

I think that in the line img = tf.io.read_file(file_path), img is the image name instead of the actual image. To resolve this problem, see here.
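A quick way to check what `tf.io.read_file` returns: the sketch below writes a few bytes to a temporary file (the file name and bytes are made up for the check) and reads them back. The result is a scalar `tf.string` tensor holding the file's contents, which is the input format `tf.image.decode_jpeg` in `decode_img` expects.

```python
import os
import tempfile
import tensorflow as tf

# Arbitrary bytes standing in for a JPEG on disk (not a decodable image).
payload = b"\xff\xd8\xff\xe0 not a real jpeg"

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.jpg")
    with open(path, "wb") as f:
        f.write(payload)

    # tf.io.read_file returns the raw file contents as a scalar tf.string tensor.
    contents = tf.io.read_file(path)
    data = contents.numpy()

print(contents.dtype, contents.shape, data == payload)
```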

It worked for me. Let me know!

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM