繁体   English   中英

Tensorflow的decode_csv仅读取一行

[英]Tensorflow's decode_csv only reading one line

我怎样才能得到decode_csv函数来读取CSV中的每一行?

我目前正在尝试将CS​​V文件中的数据加载到GPU中。 数据可以很好地加载到GPU上,除了...实际读取的640行CSV文件中只有一行。 您认为我在哪里错了?

import tensorflow as tf

with tf.device('/gpu:0'):
    filename_queue = tf.train.string_input_producer(['dataset.csv'])
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    record_defaults = [['']]*121
    all_columns = tf.decode_csv(value, record_defaults=record_defaults)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Iterate through all the columns
        vals = []
        for x in range(121):
            tmp = all_columns.pop()
            myval = tmp.eval(session=sess)
            vals.append(myval)

        coord.request_stop()
        coord.join(threads)

那我如果...

>>> import numpy as np
>>> vals = np.asarray(vals)
>>> vals.shape
(121,)

我的CSV中的640行中的每行确实有121列。 vals的值对我来说看起来不错,除了我实际上并没有全部读取640行。 我猜它与:

all_columns = tf.decode_csv(value, record_defaults=record_defaults)

NVM。 弄清楚了。

显然,就如何提取行数据而言, sess.run()pop()之间存在差异。

我碰巧在CSV文件和121列中有640行,因此:

record_defaults = [['']]*121

for x in range(640):

请注意,这大多是硬编码的,仅用于测试目的。 解决方案如下:

import tensorflow as tf

with tf.device('/gpu:0'):
filename_queue = tf.train.string_input_producer(['../Datasets/CMU_face_images_dataset.csv'])
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    record_defaults = [['']]*121
    all_columns = tf.decode_csv(value, record_defaults=record_defaults)

    # TWO NEW LINES
    name = all_columns[0]
    data = all_columns[1:]

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        vals = []
        names = []
        for x in range(640):

            # THIS IS THE NEW LINE
            _name, _val = sess.run([name, data])

            # OLD LINES
            # tmp = all_columns.pop()
            # myval = tmp.eval(session=sess)
            # vals.append(myval)

            names.append(_name)
            vals.append(_val)

        coord.request_stop()
        coord.join(threads)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM