Tensorflow: InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string

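I need to zero-pad the tensors I read from a TFRecord file ("ARRAY" in the snippet below). The model I am training requires all of them to have the same shape, but my inputs have different widths and lengths. So I tried to compute how many zeros to pad ("paddings = tf.Variable([[0,targetLength],[0,targetWidth]])" in the snippet). However, TensorFlow raises an InvalidArgumentError about a tensor named "arg0", a name that never appears anywhere in my code.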

An example of the variable ARRAY:

[1,0,1,0]
[2,0,2,0]

It should be padded to:

[1,0,1,0,0,0]
[2,0,2,0,0,0]
[0,0,0,0,0,0]
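The padding in this example can be checked with a small pure-Python sketch that needs no TensorFlow. `pad_to_shape` is a hypothetical helper written for illustration. It builds a `paddings` list in the same `[[before, after], ...]` per-dimension layout that `tf.pad` takes. Zeros go only after the existing rows and columns, and the fill value is zero, as with `tf.pad(..., "CONSTANT")`.

```python
def pad_to_shape(array, target_length, target_width, value=0):
    """Pad a rectangular 2-D list of lists to (target_length, target_width).

    Padding is added only at the end of each dimension, i.e. the paddings are
    [[0, target_length - length], [0, target_width - width]].
    """
    length = len(array)
    width = len(array[0]) if array else 0
    paddings = [[0, target_length - length], [0, target_width - width]]
    if any(p < 0 for pair in paddings for p in pair):
        raise ValueError("array is larger than the target shape")
    # Extend every existing row on the right.
    rows = [list(row) + [value] * paddings[1][1] for row in array]
    # Append all-zero rows at the bottom.
    rows += [[value] * target_width for _ in range(paddings[0][1])]
    return rows


padded = pad_to_shape([[1, 0, 1, 0], [2, 0, 2, 0]], 3, 6)
for row in padded:
    print(row)
# [1, 0, 1, 0, 0, 0]
# [2, 0, 2, 0, 0, 0]
# [0, 0, 0, 0, 0, 0]
```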

There may be many large arrays, so I want to do the padding on the fly, right before training. Here is my code snippet.

import sys

import tensorflow as tf  # TensorFlow 1.x API

# load_data (my own module), model_fn and num_steps are defined elsewhere in the script (not shown).


def my_input_fn(file_path, perform_shuffle=True, repeat_count=1):
    global width, length  # two int64 variables.
    batchNum = 32

    def parse_ARRAY(tfrecord):
        features = tf.parse_single_example(
            tfrecord,
            # Defaults are not specified since all keys are required.
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'length': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'ARRAY': tf.FixedLenFeature([], tf.string)
            })

        ARRAY = tf.decode_raw(features['ARRAY'], tf.int64)
        label = tf.cast(features['label'], tf.int64)
        ARRAYLength = tf.cast(features['length'], tf.int64)
        ARRAYWidth = tf.cast(features['width'], tf.int64)
        ARRAYshape = tf.stack([ARRAYLength, ARRAYWidth])
        ARRAY = tf.reshape(ARRAY, ARRAYshape)
        TFWidth = tf.convert_to_tensor(width, tf.int64)
        TFLength = tf.convert_to_tensor(length, tf.int64)
        targetWidth = tf.subtract(TFWidth, ARRAYWidth)
        targetLength = tf.subtract(TFLength, ARRAYLength)
        paddings = tf.Variable([[0, targetLength], [0, targetWidth]])
        with tf.Session() as sess:
            sess.run(paddings.initializer)
        tf.pad(ARRAY, paddings, "CONSTANT")

        return {"ARRAY": ARRAY}, label

    dataset = tf.data.TFRecordDataset(file_path)
    dataset = dataset.map(parse_ARRAY)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(repeat_count)  # Repeat the dataset repeat_count times
    dataset = dataset.batch(batchNum)  # Batch size to use
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels


def run_tfr(args):
    global length, width
    model = tf.estimator.Estimator(model_fn)
    model.train(input_fn=lambda: my_input_fn(args[0]), steps=num_steps)
    e = model.evaluate(input_fn=lambda: my_input_fn(args[0] + ".tests"))
    print("Testing Accuracy:", e['accuracy'])


if __name__ == "__main__":
    width, length = load_data.loadInfo(sys.argv[1])
    # the usage is 'python thisfile.py file.pkl' or
    # 'python thisfile.py file.tfrecord'
    if sys.argv[1].endswith(".pkl"):
        # handle a file from cPickle.
        pass
    elif sys.argv[1].endswith("tfrecord"):
        run_tfr(sys.argv[1:])

Here is TensorFlow's output:

Caused by op u'arg0', defined at:
  File "DLCNN.py", line 230, in <module>
    run_tfr(sys.argv[1:])
  File "DLCNN.py", line 215, in run_tfr
    model.train(input_fn = lambda: my_input_fn(args[0]), steps=num_steps)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 708, in _train_model
    input_fn, model_fn_lib.ModeKeys.TRAIN)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 577, in _get_features_and_labels_from_input_fn
    result = self._call_input_fn(input_fn, mode)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 663, in _call_input_fn
    return input_fn(**kwargs)
  File "DLCNN.py", line 215, in <lambda>
    model.train(input_fn = lambda: my_input_fn(args[0]), steps=num_steps)
  File "DLCNN.py", line 149, in my_input_fn
    dataset = dataset.map(parse_ARRAY)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
    return MapDataset(self, map_func)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
    self._create_definition_if_needed()
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 334, in _create_definition_if_needed_impl
    argholder = array_ops.placeholder(argtype, name=argname)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1599, in placeholder
    return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3091, in _placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 703, in create_op
    **kwargs)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
         [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

By the way, if I replace sess.run(paddings.initializer) with the following code, TensorFlow does not report an error:

node1 = tf.constant(3.0, dtype=tf.float32)
node2 = tf.constant(4.0)
node3 = tf.add(node1, node2)
print("sess.run(node3):", sess.run(node3))

I would also like to know whether there are other ways to pad an array read from a TFRecord. Thanks.

This answer focuses on the follow-up question: "are there other ways to pad an array read from a TFRecord?"

I use tf.shape to extract the array's shape, and then pad the array with tf.pad.

Here is the code:

def my_input_fn(file_path, perform_shuffle=True, repeat_count=1):
    global width, length
    batchNum = 32

    def parse_ARRAY(tfrecord):
        features = tf.parse_single_example(
            tfrecord,
            # Defaults are not specified since all keys are required.
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'length': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'ARRAY': tf.FixedLenFeature([], tf.string)
            })

        ARRAY = tf.decode_raw(features['ARRAY'], tf.int64)
        label = tf.cast(features['label'], tf.int64)
        ARRAYLength = tf.cast(features['length'], tf.int64)
        ARRAYWidth = tf.cast(features['width'], tf.int64)
        ARRAYshape = tf.stack([ARRAYLength, ARRAYWidth])
        ARRAY = tf.reshape(ARRAY, ARRAYshape)
        # Read the actual shape back from the reshaped tensor.
        height = tf.shape(ARRAY)[0]
        imgWidth = tf.shape(ARRAY)[1]
        # A plain nested list: no tf.Variable and no Session are needed.
        paddings = [[0, length - height], [0, width - imgWidth]]

        #ARRAY = array_ops.pad(ARRAY, paddings)
        ARRAY = tf.pad(ARRAY, paddings, "CONSTANT")
        return {"ARRAY": ARRAY}, label
        ......

This achieves the goal of padding the arrays read from the TFRecord. However, I still don't know what causes the InvalidArgumentError.
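The traceback suggests a likely cause. It shows that 'arg0' is the string placeholder that dataset.map creates for parse_ARRAY's tfrecord argument while it builds the mapped function (see the array_ops.placeholder call under _create_definition_if_needed_impl). paddings.initializer depends on targetLength and targetWidth, and therefore on that placeholder. Running it in a tf.Session inside parse_ARRAY thus asks for a value that only exists once the pipeline runs. The constant nodes in the question do not depend on the input, which would explain why they run without error. The answer avoids the problem by computing paddings as ordinary per-record values and assigning the result of tf.pad.

The plain-Python sketch below mirrors that per-record flow without TensorFlow:

1. Reshape the flat data with the stored length and width.
2. Read the actual shape back, the role tf.shape plays above.
3. Pad with zeros at the end of each dimension.
4. Batch the results.

The record dict layout and the names reshape, parse_and_pad and batch are hypothetical, chosen for illustration only; they are not TensorFlow APIs.

```python
def reshape(flat, length, width):
    # Plain-list counterpart of tf.reshape(ARRAY, [length, width]).
    if len(flat) != length * width:
        raise ValueError("flat data does not match length * width")
    return [flat[i * width:(i + 1) * width] for i in range(length)]


def parse_and_pad(record, target_length, target_width):
    # record is a hypothetical dict holding the decoded TFRecord fields.
    array = reshape(record["ARRAY"], record["length"], record["width"])
    # Read the shape back from the array itself, as tf.shape does above.
    height = len(array)
    img_width = len(array[0]) if array else 0
    paddings = [[0, target_length - height], [0, target_width - img_width]]
    if paddings[0][1] < 0 or paddings[1][1] < 0:
        # tf.pad does not accept negative paddings either.
        raise ValueError("record is larger than the target shape")
    # Zeros go after the existing columns, then after the existing rows.
    padded = [list(row) + [0] * paddings[1][1] for row in array]
    padded += [[0] * target_width for _ in range(paddings[0][1])]
    return {"ARRAY": padded}, record["label"]


def batch(examples, batch_size):
    # Group items lazily; like dataset.batch(batch_size), the last
    # batch may be smaller when the data does not divide evenly.
    buf = []
    for example in examples:
        buf.append(example)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf


records = [
    {"ARRAY": [1, 0, 1, 0, 2, 0, 2, 0], "length": 2, "width": 4, "label": 1},
    {"ARRAY": [5, 6, 7], "length": 1, "width": 3, "label": 0},
]
# Each record is padded only when the generator reaches it, so the
# padding happens on the fly rather than for the whole file up front.
batches = list(batch((parse_and_pad(r, 3, 6) for r in records), batch_size=32))
```

Because every padded array ends up with the same shape, each batch is rectangular, which is the property the model needs.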

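Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you reproduce them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM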