[英]How to construct square of pairwise difference for each row vector in a matrix in tensorflow?
I have a 2D tensor having K*N dimension in TensorFlow, 我在TensorFlow中有一个具有K * N维的2D张量,
For each row vector in the tensor, having N dimension, I can calculate the square of pairwise difference using the approach in How to construct square of pairwise difference from a vector in tensorflow? 对于张量中具有N维的每个行向量,我可以使用如何从张量流中的向量构造成对差平方的方法来计算成对差平方。
However, I need to average the results of the K row vectors: performing each vector's square of pairwise difference and averaging the results. 但是,我需要对K个行向量的结果求平均:执行每个向量的成对差平方,并对结果求平均。
How can I do? 我能怎么做? Need your help, many thanks!!! 需要您的帮助,非常感谢!!!
The code and then the run results: 代码然后运行结果:
a = tf.constant([[1,2,3],[2,5,6]])
a = tf.expand_dims(a,1)
at = tf.transpose(a, [0,2,1])
pair_diff = tf.matrix_band_part( a - at, 0, -1)
output = tf.reduce_sum(tf.square(pair_diff), axis=[1,2])
final = tf.reduce_mean(output)
with tf.Session() as sess:
print(sess.run(a - at))
print(sess.run(output))
print(sess.run(final))
Give this results: 给出以下结果:
1) a - at
(computes the same thing of the link you posted but rowise) 1) a - at
(计算与您发布的链接相同但包含rowise的内容)
[[[ 0 1 2]
[-1 0 1]
[-2 -1 0]]
[[ 0 3 4]
[-3 0 1]
[-4 -1 0]]]
2) output
(take the matrix band part and sum all dimensions apart from rows, ie you have the result of the code you posted for each row) 2) output
(取矩阵带部分并求和除行以外的所有维,即,您得到每行发布的代码的结果)
[ 6 26]
3) final
Average among rows 3)行之间的final
平均值
16
Similar logic to How to construct square of pairwise difference from a vector in tensorflow? 与如何从张量流中的向量构造成对差平方的逻辑相似? but some changes required to handle 2d: 但是要处理2d,需要进行一些更改:
a = tf.constant([[1,2,3], [4, 6, 8]])
pair_diff = tf.transpose(a[...,None, None,] - tf.transpose(a[...,None,None,]), [0,3,1,2])
reshape_diff = tf.reshape(tf.matrix_band_part(pair_diff, 0, -1), [-1, tf.shape(a)[1]*tf.shape(a)[1]])
output = tf.reduce_sum(tf.square(reshape_diff),1)[::tf.shape(a)[0]+1]
with tf.Session() as sess:
print(sess.run(output))
#[ 6 24]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.