[英]How do I return the value of a tensor from a function in TensorFlow?
我正在Keras開展深度學習項目,並使用TensorFlow后端實現了靈敏度功能,因為如果我想使用它來評估模型,則需要這樣做。 但是,我無法從張量中提取值。 我想返回它,以便我可以在其他函數中使用這些值。 理想情況下,返回值應為int
。 每當我評估函數時,我只得到張量對象本身,而不是它的實際值。
我試過創建一個會話和評估,但無濟於事。 我能夠以這種方式打印該值,但我無法將值分配給另一個變量。
def calculate_tp(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = K.equal(y, i)
preds = K.equal(y_pred, i)
TP += K.sum(K.cast(tf.boolean_mask(preds, tf.math.equal(true, True)), 'int32'))
FP += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(~preds, True)), 'int32'))
TN += K.sum(K.cast(tf.boolean_mask(~preds, tf.math.equal(true, True)), 'int32'))
FN += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(preds, False)), 'int32'))
"""with tf.Session() as sess:
TP = TP.eval()
FP = FP.eval()
FN = FN.eval()
FP = FP.eval()
print(TP, FP, TN, FN)
#sess.run(FP)"""
return TP / (TP + FN)
如果我很好地理解你的問題,你可以簡單地創建一個新的張量,其值已被忽略。
例如 :
tensor = tf.constant([5, 5, 8, 6, 10, 1, 2])
tensor_value = tensor.eval(session=tf.Session())
print(tensor_value) #get [ 5 5 8 6 10 1 2]
new_tensor = tf.constant(tensor_value)
print(new_tensor) #get Tensor("Const_25:0", shape=(7,), dtype=int32)
希望我幫忙!
好吧,因為在你的嘗試中,TP總是0?
如果我嘗試:
y = np.array([0, 0, 0, 0, 0, 1, 1, 1 ,1 ,1])
y_pred = np.array([0.01, 0.005, 0.5, 0.09, 0.56, 0.999, 0.89, 0.987 ,0.899 ,1])
def calculate_tp(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = K.equal(y, i)
preds = K.equal(y_pred, i)
TP += K.sum(K.cast(tf.boolean_mask(preds, tf.math.equal(true, True)), 'int32'))
FP += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(~preds, True)), 'int32'))
TN += K.sum(K.cast(tf.boolean_mask(~preds, tf.math.equal(true, True)), 'int32'))
FN += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(preds, False)), 'int32'))
TP = TP.eval(session=tf.Session())
FP = FP.eval(session=tf.Session())
TN = TN.eval(session=tf.Session())
FN = FN.eval(session=tf.Session())
print(TP, FP, TN, FN)
results = TP / (TP + FN)
return results
res = calculate_tp(y, y_pred)
print(res)
#Outputs :
#0 5 5 5
#1 9 9 9
#1 9 9 9
#1 9 9 9
#1 9 9 9
#0.1
它給了我一個浮點數,就像你想要的那樣。
它有幫助嗎?
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.