繁体   English   中英

TensorFlow Tensor在numpy argmax与keras argmax中的处理方式不同

[英]TensorFlow Tensor handled differently in numpy argmax vs keras argmax

为什么TensorFlow张量在Numpy中的数学函数中的表现与在Keras中的数学函数中表现不同?

当与TensorFlow Tensor处于相同的情况时,Numpy数组似乎正常运行。

这个例子表明在numpy函数和keras函数下正确处理numpy矩阵。

import numpy as np
from keras import backend as K

arr = np.random.rand(19, 19, 5, 80)

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

输出(如预期的那样)

np_argmax shape:  (19, 19, 5)
np_max shape:  (19, 19, 5)
k_argmax shape:  (19, 19, 5)
k_max shape:  (19, 19, 5)

与此示例相反

import numpy as np
from keras import backend as K
import tensorflow as tf

arr = tf.constant(np.random.rand(19, 19, 5, 80))

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

哪个输出

np_argmax shape:  ()
np_max shape:  (19, 19, 5, 80)
k_argmax shape:  (19, 19, 5)
k_max shape:  (19, 19, 5)

您需要执行/运行代码(例如在TF会话下)以评估张量。 在此之前,不评估张量的形状。

TF文档说:

Tensor中的每个元素都具有相同的数据类型,并且数据类型始终是已知的。 形状(即,它具有的尺寸数量和每个尺寸的大小)可能只是部分已知。 如果输入的形状也是完全已知的,大多数操作都会生成已知形状的张量,但在某些情况下,只能在图形执行时找到张量的形状。

为什么不为第二个示例尝试以下代码:

import numpy as np
from keras import backend as K
import tensorflow as tf

arr = tf.constant(np.random.rand(19, 19, 5, 80))
with tf.Session() as sess:
    arr = sess.run(arr)

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

arr = tf.constant(np.random.rand(19, 19, 5, 80))arr的类型是tf.Tensor ,但是在运行arr = sess.run(arr)它的类型将变为numpy.ndarray

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM