
Dealing with unknown dimensions in TensorFlow

I wrote a function that calculates the Gram matrix for image features of shape (1, H, W, C). The function is below:

def calc_gram_matrix(features, normalize=True):
  # input: features is a tensor of shape (1, Height, Width, Channels)

  _, H, W, C = features.shape
  matrix = tf.reshape(features, shape=[-1, int(C)])
  gram = tf.matmul(tf.transpose(matrix), matrix)

  if normalize:
    tot_neurons = H * W * C
    gram = tf.divide(gram, tot_neurons)

  return gram

To test my implementation of the Gram matrix, I use this function:

def gram_matrix_test(correct):
    gram = calc_gram_matrix(model.extract_features()[5])
    student_output = sess.run(gram, {model.image: style_img_test})
    print(style_img_test.shape)
    error = rel_error(correct, student_output)
    print('Maximum error is {:.3f}'.format(error))

gram_matrix_test(answers['gm_out'])

When I run gram_matrix_test(), I get this error:

ValueError: Cannot convert an unknown Dimension to a Tensor: ?

The error is raised on the line gram = tf.divide(gram, tot_neurons).

While debugging, I found that the shape of model.extract_features()[5] is (?, ?, ?, 128). H and W are therefore unknown when the graph is built, so tot_neurons is unknown and the division is not possible.

style_img_test has shape (1, 192, 242, 3), so H, W, and C only get concrete values once the session runs.
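The gap between what is known at graph-construction time and what is known at run time can be illustrated with a small sketch. It assumes a TensorFlow build that exposes the tf.compat.v1 API; the placeholder named features is a hypothetical stand-in for model.extract_features()[5], and the fed array is made-up data, not the real style image.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumed to be available

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for model.extract_features()[5]:
    # only the channel count is known when the graph is built.
    features = tf1.placeholder(tf.float32, shape=[None, None, None, 128])

    # Static shape: fixed at graph-construction time, unknown dims are None.
    static_shape = features.shape.as_list()

    # Dynamic shape: an int32 tensor that is evaluated when the session runs.
    dynamic_shape = tf.shape(features)

with tf1.Session(graph=graph) as sess:
    fed = np.zeros((1, 4, 5, 128), dtype=np.float32)  # made-up input
    run_shape = sess.run(dynamic_shape, {features: fed})

print(static_shape)
print(run_shape)
```

The static shape still holds None for the batch, height, and width, while the value fetched from tf.shape reflects whatever array was fed.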

How can I fix this?

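I made the following changes and it worked. tf.shape(features) returns the shape as a tensor that is evaluated when the session runs, whereas features.shape is the static shape, whose unknown dimensions cannot be converted to a tensor.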

def calc_gram_matrix(features, normalize=True):
  # input: features is a tensor of shape (1, Height, Width, Channels)

  # Dynamic shape: a tensor whose values are filled in at run time
  features_shape = tf.shape(features)
  H = features_shape[1]
  W = features_shape[2]
  C = features_shape[3]

  matrix = tf.reshape(features, shape=[-1, C])
  gram = tf.matmul(tf.transpose(matrix), matrix)

  if normalize:
    tot_neurons = H * W * C
    # H * W * C is int32; cast so it can divide the float32 gram matrix
    tot_neurons = tf.cast(tot_neurons, tf.float32)

    gram = tf.divide(gram, tot_neurons)

  return gram
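
One way to check a function like this is to compare it with a plain NumPy computation on random data. The following is only a sketch under assumptions: it uses the tf.compat.v1 API, a hypothetical placeholder named image in place of the model's feature tensor, and a made-up input of shape (1, 6, 7, 128).

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumed to be available


def calc_gram_matrix(features, normalize=True):
    # Read H, W, C from the dynamic shape so unknown static dims are fine.
    shape = tf.shape(features)
    H, W, C = shape[1], shape[2], shape[3]
    matrix = tf.reshape(features, [-1, C])
    gram = tf.matmul(tf.transpose(matrix), matrix)
    if normalize:
        gram = tf.divide(gram, tf.cast(H * W * C, tf.float32))
    return gram


graph = tf.Graph()
with graph.as_default():
    # Hypothetical input: height and width unknown until data is fed.
    image = tf1.placeholder(tf.float32, shape=[None, None, None, 128])
    gram = calc_gram_matrix(image)

rng = np.random.RandomState(0)
x = rng.rand(1, 6, 7, 128).astype(np.float32)  # made-up features

with tf1.Session(graph=graph) as sess:
    tf_gram = sess.run(gram, {image: x})

# Reference computation in NumPy with the same normalization.
flat = x.reshape(-1, 128)
np_gram = flat.T.dot(flat) / (6 * 7 * 128)
max_abs_diff = float(np.max(np.abs(tf_gram - np_gram)))
print(tf_gram.shape, max_abs_diff)
```

If the two results agree up to float32 rounding, the dynamic-shape version computes the same normalized Gram matrix as the direct formula.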
