
Logits don't have the proper shape in VGG16 model for X-ray classification

I got the same error as here.

I have two classes, so the labels should have rank 1 (which they do) and the logits should have rank 2, with shape [batch_size, num_classes]. Here, however, the logits have rank 4, with shape [?, 3, 3, 2]. The linked answer mentions using flatten() on the convolutional layer(s). Which layers should flatten() be applied to: all of the convolutions, or only the last one in the network (fc8)? Here's what I have in vgg.py:

with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc:
  end_points_collection = sc.original_name_scope + '_end_points'
  # Collect outputs for conv2d, fully_connected and max_pool2d.
  with arg_scope(
      [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d],
      outputs_collections=end_points_collection):
    net = layers_lib.repeat(
        inputs, 2, layers.conv2d, 64, [3, 3], scope='conv1')
    net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')
    net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2')
    net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')
    net = layers_lib.repeat(net, 3, layers.conv2d, 256, [3, 3], scope='conv3')
    net = layers_lib.max_pool2d(net, [2, 2], scope='pool3')
    net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv4')
    net = layers_lib.max_pool2d(net, [2, 2], scope='pool4')
    net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv5')
    net = layers_lib.max_pool2d(net, [2, 2], scope='pool5')
    # Use conv2d instead of fully_connected layers.
    net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
    net = layers_lib.dropout(
        net, dropout_keep_prob, is_training=is_training, scope='dropout6')
    net = layers.conv2d(net, 4096, [1, 1], scope='fc7')
    net = layers_lib.dropout(
        net, dropout_keep_prob, is_training=is_training, scope='dropout7')
    net = layers.conv2d(
        net,
        num_classes, [1, 1],
        activation_fn=None,
        normalizer_fn=None,
        scope='fc8')

    # Convert end_points_collection into a end_point dict.
    end_points = utils.convert_collection_to_dict(end_points_collection)
    if spatial_squeeze:
      net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed')
      end_points[sc.name + '/fc8'] = net
    return net, end_points
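To see where a [?, 3, 3, 2] logits shape can come from, it may help to trace the spatial size through the layers above. The sketch below is not part of vgg.py. It assumes square inputs, size-preserving 3x3 convolutions (SAME padding), five 2x2 stride-2 max-pools, and the 7x7 VALID convolution at fc6. The function name and the 288-pixel input are hypothetical, since the question does not state the image size.

```python
def vgg16_fc8_spatial_size(input_size):
    """Side length of the fc8 output map for a square input.

    Assumptions (not taken from the question): SAME-padded 3x3 convs keep
    the size, each 2x2 stride-2 max-pool halves it (rounding down), and
    the 7x7 VALID conv at fc6 removes 6 pixels per side length.
    """
    size = input_size
    for _ in range(5):          # pool1 .. pool5
        size //= 2
    size = size - 7 + 1         # fc6: 7x7 kernel, VALID padding
    if size < 1:
        raise ValueError("input too small for the 7x7 VALID conv at fc6")
    return size                 # fc7 and fc8 are 1x1 convs: size unchanged


print(vgg16_fc8_spatial_size(224))  # 1 -> squeeze over axes [1, 2] works
print(vgg16_fc8_spatial_size(288))  # 3 -> logits keep two 3-wide spatial axes
```

Under these assumptions, only inputs that shrink to a 7x7 map before fc6 (such as 224x224) produce a 1x1 fc8 output, which is the case array_ops.squeeze(net, [1, 2]) is written for.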

This is where I get the error in nn_ops.py:

  with ops.name_scope(name, "SparseSoftmaxCrossEntropyWithLogits",
                      [labels, logits]):
    labels = ops.convert_to_tensor(labels)
    logits = ops.convert_to_tensor(logits)

    precise_logits = math_ops.cast(logits, dtypes.float32) if (
        dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits

    # Store label shape for result later.
    labels_static_shape = labels.get_shape()
    print(labels_static_shape)
    labels_shape = array_ops.shape(labels)

    print("labels_shape:", labels_shape)
    print("logits.get_shape().ndims:" ,(logits.get_shape().ndims))

    if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
      raise ValueError("Logits cannot be scalars - received shape %s." %
                       logits.get_shape())
    if logits.get_shape().ndims is not None and (
        labels_static_shape.ndims is not None and
        labels_static_shape.ndims != logits.get_shape().ndims - 1):
      raise ValueError("Rank mismatch: Rank of labels (received %s) should "
                       "equal rank of logits minus 1 (received %s)." %
                       (labels_static_shape.ndims, logits.get_shape().ndims))

And this is my error:

  File "/home/smolaei/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/nn_ops.py", line 1695, in sparse_softmax_cross_entropy_with_logits
    (labels_static_shape.ndims, logits.get_shape().ndims))
ValueError: Rank mismatch: Rank of labels (received 1) should equal rank of logits minus 1 (received 4).

Any recommendation on how to fix this?

Typically you'd flatten just before the output layers, because the convolutions themselves need non-flat (4-D) input. See the TensorFlow MNIST model for an example.
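As a rough illustration of the shapes involved, here is a NumPy sketch (not TensorFlow code) of two ways to end up with rank-2 logits: averaging the fc8 map over its spatial axes, or flattening a feature map and feeding it to a dense output layer. The helper names, the 8-channel feature map, and the random weights are all made up for the example.

```python
import numpy as np


def spatial_mean_logits(logits_4d):
    """Average [batch, H, W, classes] over H and W -> [batch, classes]."""
    return logits_4d.mean(axis=(1, 2))


def flatten_then_dense(features_4d, weights, bias):
    """Flatten [batch, H, W, C] to [batch, H*W*C], then apply a dense layer."""
    batch = features_4d.shape[0]
    flat = features_4d.reshape(batch, -1)
    return np.dot(flat, weights) + bias


rng = np.random.default_rng(0)
batch, num_classes = 5, 2

# Option A: logits shaped like the [?, 3, 3, 2] tensor from the question.
logits_4d = rng.normal(size=(batch, 3, 3, num_classes))
logits_a = spatial_mean_logits(logits_4d)

# Option B: a 3x3 feature map (8 channels, chosen only for illustration)
# flattened before a dense output layer with hypothetical weights.
features = rng.normal(size=(batch, 3, 3, 8))
weights = rng.normal(size=(3 * 3 * 8, num_classes))
bias = np.zeros(num_classes)
logits_b = flatten_then_dense(features, weights, bias)

labels = rng.integers(0, num_classes, size=batch)
print(logits_a.shape, logits_b.shape, labels.shape)
```

Both results have shape [batch, num_classes], so the rank of the labels equals the rank of the logits minus one. Note that flattening the fc8 output itself would give [batch, 18] rather than [batch, 2], which is why the flatten goes before the output layer. In the TensorFlow graph, the counterpart of option A would presumably be tf.reduce_mean(net, [1, 2]) applied to the fc8 output. Another option is resizing the input images to 224x224 so that fc8 is already 1x1 and spatial_squeeze applies.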
