TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float64, int32, uint8, int16, int8
I have a dataset with 5 labels:
import os
import tensorflow as tf

def get_label(file_path):
    # convert the path to a list of path components
    parts = tf.strings.split(file_path, os.path.sep)
    class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    # The second to last is the class-directory
    one_hot = parts[-2] == class_names
    # Integer encode the label
    return tf.argmax(one_hot)

def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # resize the image to the desired size
    return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
    label = get_label(file_path)
    # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img, label

train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
If I adapt this code to another dataset that has 2 labels, class_names = ['dog', 'cat'], I get this error:
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64
So how can I update get_label(file_path)?
I was having the same problem. Following the idea of the other answer, cast the result of the comparison to an integer type:
one_hot = tf.dtypes.cast(parts[-2] == class_names, dtype = tf.int16)
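For context, here is a minimal sketch of how the whole get_label function might look with that cast applied. The two-element class_names list comes from the question; the example path in the usage note is hypothetical, and tf.cast is used as the shorter alias of tf.dtypes.cast.

```python
import os

import tensorflow as tf

# Class list for the two-label dataset described in the question.
class_names = ['dog', 'cat']


def get_label(file_path):
    # Split the path into components; the second to last is the class directory.
    parts = tf.strings.split(file_path, os.path.sep)
    # The comparison broadcasts against class_names and yields a bool tensor,
    # so cast it to an integer type before passing it to tf.argmax.
    one_hot = tf.cast(parts[-2] == class_names, tf.int16)
    # Integer-encode the label as the index of the matching class name.
    return tf.argmax(one_hot)
```

Called on a path such as os.path.join('data', 'cat', 'img1.jpg') (a made-up path), this should return the index of 'cat' in class_names. The rest of the pipeline (decode_img, process_path, the map call) would stay unchanged.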
My guess would be that tf.argmax requires one of these data types (I can't test this right now):
float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64
So all you need to do is convert the output of
one_hot = parts[-2] == class_names
to an integer type. The "==" comparison evaluates to True/False (a bool tensor), which is probably not allowed.
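The guess above can be checked with a small sketch, assuming eager execution in TensorFlow 2.x. The variable folder is a hypothetical stand-in for parts[-2]. Whether tf.argmax rejects the uncast bool tensor may depend on the installed TensorFlow version, so the snippet only inspects the dtype and applies the cast, which should be harmless either way.

```python
import tensorflow as tf

class_names = ['dog', 'cat']
# Hypothetical stand-in for parts[-2], the class directory name.
folder = tf.constant('cat')

# Comparing a string tensor with the list gives one True/False per class name.
mask = folder == class_names
print(mask.dtype)

# Casting maps True/False to 1/0, an integer dtype from the allowed list.
label = tf.argmax(tf.cast(mask, tf.int32))
print(int(label))
```

Any integer or float type from the allowed list should work for the cast; int32 is used here only as an example.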