
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 768 values, but the requested shape has 3072

Error:

Caused by op 'Reshape', defined at:
  File "train.py", line 72, in <module>
    tf.app.run()
  File "/home/fzs/anaconda3/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "train.py", line 69, in main
    train()
  File "train.py", line 41, in train
    logit = inference.inference(image_batch, True, regularizer)
  File "/home/fzs/Codes/Fisrt_for_test/inference.py", line 67, in inference
    reshaped = tf.reshape(pool3, [-1, nodes])
  File "/home/fzs/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2451, in reshape
    name=name)
  File "/home/fzs/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/fzs/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/fzs/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 768 values, but the requested shape has 3072
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](DecodeRaw, Reshape/shape)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[-1,32,32,3], [-1]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

My training code:

train_files = tf.train.match_filenames_once(INPUT_DIR)
datasets = tf.data.TFRecordDataset(train_files)
datasets = datasets.map(parser).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
datasets = datasets.repeat(EPOCH)
iterator = datasets.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()
print(image_batch.shape)
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)

logit = inference.inference(image_batch, True, regularizer)
loss = calc_loss(logit, label_batch)
global_step = tf.Variable(0, trainable=False)

variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variable_averages_op = variable_averages.apply(tf.trainable_variables())

learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, DATA_NUM / BATCH_SIZE, LEARNING_RATE_DECAY, staircase=False)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
train_op = tf.group(variable_averages_op, train_step)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(iterator.initializer)
    while True:
        try:
            _, loss_value, step = sess.run([train_op, loss, global_step])
            if step % 5000 == 0:
                print("After %d training step(s), loss on training batch is %g" % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_SAVE_NAME), global_step=global_step)
        except tf.errors.OutOfRangeError:
            print('Train done')
            break

My dataset is built from CIFAR-10. I checked the shape of image_batch, and it is (?, 32, 32, 3).

My inference code:

conv1 = conv(input_tensor, [CONV1_SIZE, CONV1_SIZE, IMAGE_CHANNEL, CONV1_DEEP], C1_STRIDE, 'SAME', 'layer1-conv1')
pool1 = maxpool(conv1, POOL1_SIZE, P1_STRIDE, 'SAME', 'layer2-pool1')
conv2 = conv(pool1, [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP], C2_STRIDE, 'SAME', 'layer3-conv2')
pool2 = maxpool(conv2, POOL2_SIZE, P2_STRIDE, 'SAME', 'layer4-pool2')
conv3 = conv(pool2, [CONV3_SIZE, CONV3_SIZE, CONV2_DEEP, CONV3_DEEP], C3_STRIDE, 'SAME', 'layer5-conv3')
pool3 = maxpool(conv3, POOL3_SIZE, P3_STRIDE, 'SAME', 'layer6-pool3')

pool_shape = pool3.get_shape().as_list()
nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]    
reshaped = tf.reshape(pool3, [-1, nodes])

fc1 = fc(reshaped, [nodes, FC_SIZE], regularizer, 'layer7-fc1', False)
if train:
    fc1 = tf.nn.dropout(fc1, 0.5)
fc2 = fc(fc1, [FC_SIZE, CLASS_NUM], regularizer, 'layer8-fc2', True)

When I debug, the variables look like this:

input_tensor.shape
TensorShape([Dimension(None), Dimension(32), Dimension(32), Dimension(3)])

conv1.shape
TensorShape([Dimension(None), Dimension(32), Dimension(32), Dimension(64)])

pool1.shape
TensorShape([Dimension(None), Dimension(16), Dimension(16), Dimension(64)])

conv2.shape
TensorShape([Dimension(None), Dimension(16), Dimension(16), Dimension(128)])

pool2.shape
TensorShape([Dimension(None), Dimension(8), Dimension(8), Dimension(128)])

conv3.shape
TensorShape([Dimension(None), Dimension(8), Dimension(8), Dimension(256)])

pool3.shape
TensorShape([Dimension(None), Dimension(4), Dimension(4), Dimension(256)])

I noticed that 3072 = 32 * 32 * 3, but why does that number show up here? And I don't know what 768 means.

Can anyone help me? Thank you very much!

Update:

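Now I'm not sure whether the traceback is right, because when I kept debugging, the reshaped variable had the correct shape and no error interrupted the debugging at that point.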

reshaped.shape
TensorShape([Dimension(20), Dimension(4096)])

But debugging was interrupted at this line:

sess.run(tf.local_variables_initializer())

Error:

2018-09-19 22:25:20.055190: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at matching_files_op.cc:49 : Not found: FindFirstFile failed for: train_data 

It seems that the directory train_data does not exist. But I do have that directory:

2018/09/19  22:12    <DIR>          .
2018/09/19  22:12    <DIR>          ..
2018/09/16  22:19             3,382 check_codes.py
2018/09/19  21:45             2,712 dataset_reader.py
2018/09/15  15:16                 0 eval.py
2018/09/19  21:59             5,148 inference.py
2018/09/16  16:20    <DIR>          Model
2018/09/16  16:28                67 preprocess.py
2018/09/15  16:55    <DIR>          test_data
2018/09/19  22:15             3,106 train.py
2018/09/15  16:53    <DIR>          train_data
2018/09/19  22:19    <DIR>          __pycache__
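
As a side note, `tf.train.match_filenames_once` takes a file pattern (glob), and a relative pattern is resolved against the process's current working directory, which under a debugger may differ from the script's folder. A rough way to check what a pattern would match is sketched below with Python's standard `glob` module. Its matching rules may differ in detail from TensorFlow's, `describe_pattern` is a hypothetical helper that is not part of the question's code, and the `train_data/*` pattern is an assumption, since the value of INPUT_DIR is not shown.

```python
import glob
import os

def describe_pattern(pattern):
    """Report the working directory and the files a glob pattern matches."""
    return {
        "cwd": os.getcwd(),                     # where relative patterns are resolved
        "is_dir": os.path.isdir(pattern),       # True if the pattern itself names a directory
        "matches": sorted(glob.glob(pattern)),  # paths the pattern expands to
    }

# 'train_data' is the name from the error message; the '*' variant is an
# assumed pattern that would match the files inside that directory.
for pattern in ("train_data", os.path.join("train_data", "*")):
    print(pattern, describe_pattern(pattern))
```

If `cwd` is not the folder that contains train_data, or `matches` is empty, the pattern is probably not pointing where it was meant to.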

I have now pasted the full training code.

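768 comes from 16 * 16 * 3.

3072 comes from 32 * 32 * 3.

16x16 and 32x32 are the dimensions of the image matrix, and the 3 is there because there are three matrices, one per RGB channel. The code is hard to follow, but somewhere a conv is applied that resizes the image from 16 pixels to 32 pixels.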
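
For reference, the layer shapes listed in the question can be reproduced with the 'SAME' padding rule, where each spatial output size is ceil(input / stride). The sketch below is only an illustration: the strides (1 for each conv, 2 for each pool) are assumptions inferred from the debug output, not values shown in the question, and `same_out` is a hypothetical helper. It suggests that the flattened size passed to `tf.reshape(pool3, [-1, nodes])` is 4096, so neither 768 nor 3072 seems to come from the network itself. The failing node in the error message takes `DecodeRaw` as its input and sits under `IteratorGetNext`, which points at the dataset parser instead.

```python
import math

def same_out(size, stride):
    """Spatial output size under 'SAME' padding: ceil(size / stride)."""
    return math.ceil(size / stride)

# Assumed strides: 1 for every conv, 2 for every pool (inferred from the
# debug output in the question, not from the actual constants).
layers = [("conv1", 1, 64), ("pool1", 2, 64),
          ("conv2", 1, 128), ("pool2", 2, 128),
          ("conv3", 1, 256), ("pool3", 2, 256)]

size, shapes = 32, {}
for name, stride, depth in layers:
    size = same_out(size, stride)
    shapes[name] = (size, size, depth)

h, w, c = shapes["pool3"]
nodes = h * w * c          # flattened feature count per example
print(shapes["pool3"], nodes)
```

This agrees with the `reshaped.shape` of (20, 4096) reported in the update.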

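Thanks to everyone who looked at my question. I have figured it out now. I decoded the tfrecord as float32 when it should have been uint8, so the input could not match the network. Still, it is strange that the traceback did not show the real error.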
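
The two numbers in the error message are consistent with this explanation. A 32x32x3 CIFAR-10 image stored as raw uint8 bytes is 3072 bytes long. Reading the same bytes as 4-byte float32 values yields 3072 / 4 = 768 values, which cannot be reshaped to 32x32x3. The sketch below illustrates the arithmetic with Python's standard `array` module rather than TensorFlow. The zero-filled buffer is a stand-in for one serialized image, and the question's actual `parser` function is not shown, so what it does is an assumption.

```python
import array

# Stand-in for one serialized 32x32x3 image stored as uint8 (one byte each).
raw = bytes(32 * 32 * 3)

as_uint8 = array.array("B", raw)    # 1 byte per value
as_float32 = array.array("f", raw)  # 4 bytes per value on common platforms

print(len(as_uint8))    # number of values when read as uint8
print(len(as_float32))  # number of values when read as float32
```

In the TF 1.x API used in this thread, the corresponding fix would presumably be to call `tf.decode_raw` with `tf.uint8` in the parser and then `tf.cast` the result to `tf.float32` before it reaches the network.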
