
How to use tf.equal() with a label tensor of the shape=(1, 1) in TensorFlow?

I'm trying to evaluate individual images. It works so far. I get the individual probabilities for each class and the correct label. But when I try to get the class with tf.argmax(label, 1) I always get the class "0".

...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
print(pre)
# [[2.0451562e-06 # class 0
# 6.1964911e-06   # class 1
# 4.1852250e-06   # class 2
# 9.9847549e-01   # class 3 - We have a winner :)
# 8.2492170e-07   # class 4
# 3.1969071e-06   # class 5
# 1.5037126e-03   # class 6
# 1.6847488e-07   # class 7
# 6.7177882e-07   # class 8
# 3.4959594e-06]] # class 9
print(lbl)
# [[3]]
print(a_log)
# [3]
print(a_lbl)
# [0] # Why don't I get "3"?
...

I always get "0" for every data point. I would like to continue with tf.equal(), but that is not possible while the argmax value for the label is wrong. Any ideas?

...
image, label = ...
logits = model(image)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1) # What must I change here?
cor_pre = tf.equal(arg_log, arg_lbl)
...

Edit

I get "0" every time because tf.argmax() returns the index, not the value! I changed the question from: How to use tf.argmax() on a Tensor with the shape=(1, 1) in TensorFlow? to: How to use tf.equal() with a label tensor of the shape=(1, 1) in TensorFlow?

Based on the documentation, tf.argmax() takes an input and an axis, among other parameters.

If your label has shape [1,1], what do you expect to get from an argmax across axis 1? There is only one entry.

Most likely, you want to compare the label to the argmaxed result. So:

...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
# tf.argmax returns int64, so cast the int32 label before comparing
cor_pre = tf.equal(arg_log, tf.cast(label, tf.int64))
...
# build cor_pre before sess.run and fetch it; arg_lbl is no longer needed
pre, lbl, a_log, cor = sess.run([predic, label, arg_log, cor_pre])

Any array of shape (1, 1) contains exactly one element. That one element is necessarily the maximum, so its argmax index is always 0, whatever the value is.
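
A quick way to see this is with NumPy, whose argmax follows the same index semantics as tf.argmax() in this case. The arrays below are made-up stand-ins for the label and logits from the question, not the asker's actual data:

```python
import numpy as np

# Stand-ins for the tensors in the question (values are illustrative).
label = np.array([[3]], dtype=np.int32)             # shape (1, 1)
logits = np.array([[0.1, 0.2, 0.1, 5.0, 0.0,
                    0.3, 1.2, 0.0, 0.1, 0.2]])      # shape (1, 10)

# argmax returns the *position* of the largest entry along the axis,
# not the entry itself. With a single entry, that position is always 0.
label_argmax = np.argmax(label, axis=1)
logits_argmax = np.argmax(logits, axis=1)

print(label_argmax.tolist())   # [0]
print(logits_argmax.tolist())  # [3]
```

The label already holds the class index, so it can be compared to the argmax of the logits directly.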

I found a solution: cast the label to int64 before using it in tf.equal(). tf.argmax() returns int64 by default, while the label is int32.

...
image, label = ...
logits = model(image)
cor_pre = tf.equal(tf.argmax(logits, 1), tf.cast(label, tf.int64))
...
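
One caveat that may be worth checking: tf.argmax(logits, 1) has shape (1,) while the label has shape (1, 1), so tf.equal() relies on broadcasting here. That is harmless for a single image, but with a batch of N images the same comparison would broadcast (N,) against (N, 1) into an (N, N) result. The sketch below uses NumPy, which applies the same broadcasting rule, with invented batch values; flattening the label first (for example with tf.reshape(label, [-1]) in TensorFlow) is one way to avoid it:

```python
import numpy as np

# Hypothetical batch of 3 predictions and labels (values are illustrative).
predicted = np.array([3, 1, 7], dtype=np.int64)     # shape (3,), like tf.argmax(logits, 1)
labels = np.array([[3], [1], [2]], dtype=np.int32)  # shape (3, 1), like the label tensor

# Comparing (3,) with (3, 1) broadcasts to a (3, 3) matrix of all pairs.
pairwise = np.equal(predicted, labels.astype(np.int64))
print(pairwise.shape)        # (3, 3)

# Flattening the labels to shape (3,) gives one comparison per example.
per_example = np.equal(predicted, labels.astype(np.int64).reshape(-1))
print(per_example.tolist())  # [True, True, False]
```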
