[英]How to substract two tensor with mask in Tensorflow?
我正在實施YOLO網絡,自我損失。
說有兩個張量,GT和PD(基礎事實和預測)。兩個4x4的dims矩陣。
假設GT是:
0,0,0,0
0,1,0,0
0,0,1,0
0,0,0,0
PD具有相同的大小和一些隨機的nums。
在這里,我需要分別計算均方誤差。
使用GT中的一個來計算MSE,並且在GT中單獨使用零點來計算MSE。
我更喜歡使用掩碼來覆蓋不相關的元素,因此計算時只計算相關元素。 我已經在numpy中實現了這個,但是不知道如何使用tf(v1.14)
import numpy as np
import numpy.ma as ma
conf = y_true[...,0]
conf = np.expand_dims(conf,-1)
conf_pred = y_pred[...,0]
conf_pred = np.expand_dims(conf_pred,-1)
noobj_conf = ma.masked_equal(conf,1) #cover grid with objects
obj_conf = ma.masked_equal(conf,0) #cover grid without objects
loss_obj = np.sum(np.square(obj_conf - conf_pred))
loss_noobj = np.sum(np.square(noobj_conf - conf_pred))
有關如何在tensorflow中實現此功能的任何建議?
如果我理解正確,你想分別計算0和1的均方誤差。
您可以執行以下操作:
y_true = tf.constant([[0,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,0]], dtype=tf.float32)
y_pred = tf.random.uniform([4, 4], minval=0, maxval=1)
# find indices where 0 is present in y_true
indices0 = tf.where(tf.equal(y_true, tf.zeros([1.])))
# find indices where 1 is present in y_true
indices1 = tf.where(tf.equal(y_true, tf.ones([1.])))
# find all values in y_pred which are present at indices0
y_pred_indices0 = tf.gather_nd(y_pred, indices0)
# find all values in y_pred which are present at indices1
y_pred_indices1 = tf.gather_nd(y_pred, indices1)
# mse loss calculations
mse0 = tf.losses.mean_squared_error(labels=tf.gather_nd(y_true, indices0), predictions=y_pred_indices0)
mse1 = tf.losses.mean_squared_error(labels=tf.gather_nd(y_true, indices1), predictions=y_pred_indices1)
# mse0 = tf.reduce_sum(tf.squared_difference(tf.gather_nd(y_true, indices0), y_pred_indices0))
# mse1 = tf.reduce_sum(tf.squared_difference(tf.gather_nd(y_true, indices1), y_pred_indices1))
with tf.Session() as sess:
y_, loss0, loss1 = sess.run([y_pred, mse0, mse1])
print(y_)
print(loss0, loss1)
輸出:
[[0.12770343 0.43467927 0.9362457 0.09105921]
[0.46243036 0.8838414 0.92655015 0.9347118 ]
[0.14018488 0.14527774 0.8395766 0.14391887]
[0.1209656 0.7793218 0.70543754 0.749542 ]]
0.341359 0.019614244
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.