Is there an easy way to compute the Euclidean distance matrix in TensorFlow?

Is there an efficient way to compute the Euclidean distance matrix in TensorFlow, given either the Gram matrix or the coordinates?

The Euclidean distance matrix is an n-by-n matrix whose (i, j) entry is the squared distance between points i and j, where each point is given by its (x, y, z) coordinates. The Gram matrix is simply the matrix of inner products. So if X is a 3-by-n matrix whose columns are the points, then the Gram matrix is given by X^T@X. There is a simple formula to convert from the Gram matrix to the squared-distance matrix.
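For reference, that conversion follows from expanding the squared norm. The symbols below are notation introduced here for illustration: x_i is the i-th column of X, G is the Gram matrix, D is the squared-distance matrix, and \mathbf{1} is the all-ones column vector of length n.

```latex
\begin{aligned}
D_{ij} &= \lVert x_i - x_j \rVert^2 \\
       &= x_i^\top x_i + x_j^\top x_j - 2\, x_i^\top x_j \\
       &= G_{ii} + G_{jj} - 2\, G_{ij}, \qquad G = X^\top X, \\
D &= \operatorname{diag}(G)\,\mathbf{1}^\top + \mathbf{1}\,\operatorname{diag}(G)^\top - 2G .
\end{aligned}
```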

In other words, is there an easy way to write the following numpy/scipy functions in TensorFlow (v1 in particular)? If so, which one would be more efficient?

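# Compute the (squared) Euclidean distance matrix from the Gram matrix
def euclidean_dist_from_gram(gram):
    n = gram.shape[0]  # number of points
    temp = np.tile(np.diag(gram), (n, 1))  # temp[i, j] == gram[j, j]
    return temp + temp.T - 2*gram

# Compute the (squared) Euclidean distance matrix from the coordinates
def euclidean_dist_from_pts(xyz):
    # pdist expects one point per row, hence the transpose of the 3-by-n array
    return spt.distance.squareform(spt.distance.pdist(xyz.T, 'sqeuclidean'))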

EDIT: The above two functions can be run in numpy and scipy by first importing the following dependencies and generating random data:

import numpy as np
import scipy.spatial as spt
xyz = np.random.uniform(-20, 20, (3, 500))  # 500 random points, one per column
gram = xyz.T@xyz

I figured it out, so I'm posting an answer to my own question. It's easiest to compute the Euclidean distance matrix from the Gram matrix, so here's the TensorFlow implementation (assuming a 3-by-n coordinate matrix xyz, where n is the number of points). What I hadn't realized is that I had to call tf.expand_dims before tf.tile; with that change, the result matches the numpy implementation.

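import tensorflow as tf
n = int(xyz.shape[1])  # number of points (requires a statically known shape)
gram = tf.linalg.matmul(tf.transpose(xyz), xyz)
# diag_part has shape (n,); expand_dims makes it (1, n) so tf.tile can repeat it as n rows
g = tf.tile(tf.expand_dims(tf.linalg.diag_part(gram), 0), tf.constant([n,1]))
euclidean_dist = g + tf.transpose(g) - 2*gram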
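
As a possible variation on the snippet above, broadcasting can stand in for the explicit tf.tile, which also removes the need to know n in advance. This is only a minimal sketch: the function names squared_dist_from_gram and squared_dist_from_pts are illustrative and not from the original post.

```python
import tensorflow as tf


def squared_dist_from_gram(gram):
    """Squared Euclidean distance matrix from an n-by-n Gram matrix."""
    d = tf.linalg.diag_part(gram)  # shape (n,), d[i] = gram[i, i]
    # Broadcasting (n, 1) + (1, n) yields the (n, n) matrix d[i] + d[j],
    # so no explicit tiling is required.
    return tf.expand_dims(d, 1) + tf.expand_dims(d, 0) - 2 * gram


def squared_dist_from_pts(xyz):
    """Squared Euclidean distance matrix from a 3-by-n coordinate tensor."""
    gram = tf.linalg.matmul(xyz, xyz, transpose_a=True)  # X^T X, shape (n, n)
    return squared_dist_from_gram(gram)
```

In TensorFlow 1.x graph mode the returned tensor still has to be evaluated in a session, for example with sess.run(squared_dist_from_pts(xyz)); with eager execution it is computed immediately.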
