简体   繁体   English

TensorFlow decode_csv形状错误

[英]TensorFlow decode_csv shape error

I read in a *.csv file using tf.data.TextLineDataset and apply map on it: 我使用tf.data.TextLineDataset*.csv文件中tf.data.TextLineDataset并在其上应用map

dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
                          num_parallel_calls=num_parallel_calls)

Parse function parse_record_fn looks like this: 解析函数parse_record_fn如下所示:

def parse_record(raw_record, is_training):
    default_record = ["./", -1]
    filename, label = tf.decode_csv([raw_record], default_record)
    # do something
    return image, label

But there raise an ValueError at tf.decode_csv in parse function: 但是在解析函数中的tf.decode_csv处引发了一个ValueError

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [1], [], [].

My *.csv file example: 我的*.csv文件示例:

/data/1.png, 5
/data/2.png, 7

Question : 问题

  1. Where goes wrong? 哪里出错了?
  2. What does shapes: [1], [], [] mean? shapes: [1], [], []是什么shapes: [1], [], []是什么意思?

Reproduce 复制

This error can be reproduced in this code: 此错误可以在此代码中重现:

import tensorflow as tf
import os

def parse_record(raw_record, is_training):
    default_record = ["./", -1]
    filename, label = tf.decode_csv([raw_record], default_record)

    # do something

    return image, label

with tf.Session() as sess:
    csv_path = './labels.txt'


    dataset = tf.data.TextLineDataset(csv_path)

    dataset = dataset.map(lambda value: parse_record(value, True))


sess.run(dataset)

Looking at the documentation of tf.decode_csv , it says about the default records: 查看tf.decode_csv的文档,它说明了默认记录:

record_defaults: A list of Tensor objects with specific types. record_defaults:具有特定类型的Tensor对象列表。 Acceptable types are float32, float64, int32, int64, string. 可接受的类型是float32,float64,int32,int64,string。 One tensor per column of the input record, with either a scalar default value for that column or empty if the column is required. 输入记录的每列一个张量,具有该列的标量默认值,如果需要该列,则为空。

I believe the error you are getting originates from how you define the tensor default_record . 我相信你得到的错误源于你如何定义张量default_record Your default_record certainly is a list of tensor objects (or objects convertible to tensors), but I think the error message is telling that they should be rank-1 tensors, not rank-0 tensors as in your case. 你的default_record肯定是张量对象(或可转换为张量的对象)的列表,但我认为错误消息告诉他们应该是rank-1张量,而不是你的情况下的rank-0张量。

You can fix the issue by making the default records rank 1 tensors. 您可以通过使默认记录排名为1张张来解决问题。 See the following toy example: 请参阅以下玩具示例:

import tensorflow as tf

my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now

decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
    d = sess.run(decoded_1)
    print(d)

# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)

The error produced on the last line is familiar: 最后一行产生的错误很常见:

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV_1' (op: 'DecodeCSV') with input shapes: [], [], []. ValueError:Shape必须为1级,但对于'DecodeCSV_1'(op:'DecodeCSV'),其输入形状为[],[],[]。

In the message, the input shapes, the three brackets [] , refer to the shapes of the input arguments records , record_defaults , and field_delim of tf.decode_csv . 在消息中,输入形状,三个方括号[] ,指的是record_defaults的输入参数recordsrecord_defaultsfield_delimtf.decode_csv In your case the first of these shapes is [1] since you input [raw_record] . 在您的情况下,自输入[raw_record]以来,这些形状中的第一个是[1] [raw_record] I agree that the message for this case is not very informative... 我同意这个案子的信息不是很有用......

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM