[英]Tensorflow's decode_csv only reading one line
我怎样才能得到decode_csv
函数来读取CSV中的每一行?
我目前正在尝试将CSV文件中的数据加载到GPU中。 数据可以很好地加载到GPU上,除了...实际读取的640行CSV文件中只有一行。 您认为我在哪里错了?
import tensorflow as tf
with tf.device('/gpu:0'):
filename_queue = tf.train.string_input_producer(['dataset.csv'])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['']]*121
all_columns = tf.decode_csv(value, record_defaults=record_defaults)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Iterate through all the columns
vals = []
for x in range(121):
tmp = all_columns.pop()
myval = tmp.eval(session=sess)
vals.append(myval)
coord.request_stop()
coord.join(threads)
那我如果...
>>> import numpy as np
>>> vals = np.asarray(vals)
>>> vals.shape
(121,)
我的CSV中的640行中的每行确实有121列。 vals
的值对我来说看起来不错,除了我实际上并没有全部读取640行。 我猜它与:
all_columns = tf.decode_csv(value, record_defaults=record_defaults)
NVM。 弄清楚了。
显然,就如何提取行数据而言, sess.run()
和pop()
之间存在差异。
我碰巧在CSV文件和121列中有640行,因此:
record_defaults = [['']]*121
和
for x in range(640):
请注意,这大多是硬编码的,仅用于测试目的。 解决方案如下:
import tensorflow as tf
with tf.device('/gpu:0'):
filename_queue = tf.train.string_input_producer(['../Datasets/CMU_face_images_dataset.csv'])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['']]*121
all_columns = tf.decode_csv(value, record_defaults=record_defaults)
# TWO NEW LINES
name = all_columns[0]
data = all_columns[1:]
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
vals = []
names = []
for x in range(640):
# THIS IS THE NEW LINE
_name, _val = sess.run([name, data])
# OLD LINES
# tmp = all_columns.pop()
# myval = tmp.eval(session=sess)
# vals.append(myval)
names.append(_name)
vals.append(_val)
coord.request_stop()
coord.join(threads)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.