简体   繁体   English

按每个数组比较 numpy arrays 的两个张量 - tensorflow

[英]Compare two tensors of numpy arrays by each array - tensorflow

I have a multilabel classification problem and my y_true and y_pred during training looks like this:我有一个多标签分类问题,训练期间我的y_truey_pred如下所示:

y_true = tf.constant([[0, 1, 1, 0], [0, 1, 1, 0]])
y_pred = tf.constant([[0, 1, 0, 1], [0, 1, 1, 0]])

I want to compare those two based on each pair of lists.我想根据每对列表比较这两者。 To do so, I wrote something like为此,我写了类似的东西

values = tf.cast(x, "float32") == tf.cast(y, "float32")
bool_to_number_values = tf.cast(tranformed_values, "float32")
print(bool_to_number_values)
tranformed_values_summed = x.numpy().shape[0] - tf.reduce_sum(bool_to_number_values)
tranformed_values_summed.numpy()

This returns这返回

tf.Tensor(
[[1. 1. 0. 0.]
 [1. 1. 1. 1.]], shape=(2, 4), dtype=float32)

and -4.0 because 2.0 - 6.0 == -4.0-4.0因为2.0 - 6.0 == -4.0

But I don't want this.但我不想要这个。 I want to compare the first array of y_true to the first array of y_pred and if they are identical return True else False .我想将 y_true 的第一个数组与y_true的第一个数组进行y_pred ,如果它们相同,则返回True否则False The same logic applies for the second array of y_true and y_pred .相同的逻辑适用于y_truey_pred的第二个数组。

So the correct result should be所以正确的结果应该是

tf.Tensor(
[0,
 1], , shape=(2,), dtype=float32)

#0: because the arrays on index 0 are not equal y_true[0] <> y_pred[0]
#1: because the arrays on index 1 are equal y_true[1] == y_pred[1] 

and the tranformed_values_summed.numpy() = 2.0 - 1.0 = 1.0tranformed_values_summed.numpy() = 2.0 - 1.0 = 1.0

I think you might be looking for tf.reduce_all :我想你可能正在寻找tf.reduce_all

tf.cast(tf.reduce_all(tf.equal(y_true, y_pred), axis=-1), tf.int32)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1])>

Copy/pastable:复制/粘贴:

import tensorflow as tf

y_true = tf.constant([[0, 1, 1, 0], [0, 1, 1, 0]])
y_pred = tf.constant([[0, 1, 0, 1], [0, 1, 1, 0]])

tf.cast(tf.reduce_all(tf.equal(y_true, y_pred), axis=-1), tf.int32)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM