
Euclidean distance transform in tensorflow

I would like to create a TensorFlow function that replicates SciPy's Euclidean distance transform for each 2-dimensional matrix in my 3-dimensional tensor.

I have a 3-dimensional tensor whose third axis represents a one-hot encoded feature. For each feature dimension, I would like to create a matrix in which the value of each cell equals the distance to the nearest feature.

Example:

input = [[1 0 0]
         [0 1 0]
         [0 0 1],

         [0 1 0]
         [0 0 0]
         [1 0 0]]

output = [[0    1   1.41]
          [1    0   1   ]
          [1.41 1   0   ],

          [1    0   1   ]
          [1    1   1.41]
          [0    1   2   ]]              
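
As a cross-check, the example output can be reproduced with SciPy on the CPU. This is only a sketch: the helper name nearest_feature_distance is hypothetical, and it assumes the two 3x3 matrices are stacked along the last axis. Note that scipy.ndimage.distance_transform_edt measures the distance from non-zero cells to the nearest zero cell, so the mask has to be inverted to get the distance to the nearest feature.

```python
import numpy as np
from scipy import ndimage

# The two 3x3 feature maps from the example, stacked on the last axis.
feature_maps = np.stack([
    np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
    np.array([[0, 1, 0], [0, 0, 0], [1, 0, 0]]),
], axis=-1)

def nearest_feature_distance(mask_3d):
    """Per-channel Euclidean distance from each cell to the nearest 1."""
    out = np.empty(mask_3d.shape, dtype=np.float64)
    for c in range(mask_3d.shape[-1]):
        # distance_transform_edt measures the distance to the nearest zero,
        # so pass "mask == 0": feature cells become the zeros (False).
        out[..., c] = ndimage.distance_transform_edt(mask_3d[..., c] == 0)
    return out

distances = nearest_feature_distance(feature_maps)
print(np.round(distances[..., 0], 2))
print(np.round(distances[..., 1], 2))
```

Rounded to two decimals, the two printed matrices should match the output listed in the example.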

My current solution is implemented in Python. It iterates over every cell of a feature dimension, creates a ring around the cell, and checks whether the ring contains a feature. If it does, it calculates the distance from the cell to each feature entry in the ring and takes the minimum. If the ring contains no feature, the search ring is widened.

Code:

import numpy as np
import math

def distance_matrix():
    feature_1 = np.eye(5)
    feature_2 = np.array([[0, 1, 0, 0, 0],
                          [0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0],
                          [1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0]])
    ground_truth = np.stack((feature_1,feature_2), axis=2)
    x = np.zeros(ground_truth.shape)

    for feature_index in range(ground_truth.shape[2]):
        for i in range(ground_truth.shape[0]):
            for j in range(ground_truth.shape[1]):
                x[i,j,feature_index] = search_ring(i,j, feature_index,0,ground_truth)
    print(x[:,:,0])

def search_ring(i, j,feature_index, ring_size, truth):
    if ring_size == 0 and truth[i,j,feature_index] == 1.:
        return 0
    else:
        distance = truth.shape[0]
        y_min = max(i - ring_size, 0)
        y_max = min(i + ring_size, truth.shape[0] - 1)
        x_min = max(j - ring_size, 0)
        x_max = min(j + ring_size, truth.shape[1] - 1)

        if truth[y_min:y_max+1, x_min:x_max+1, feature_index].sum() > 0:
            for y in range(y_min, y_max + 1):
                for x in range(x_min, x_max + 1):
                    if y == y_min or y == y_max or x == x_min or x == x_max:
                        if truth[y,x,feature_index] == 1.:
                            dist = norm(i,j,y,x,type='euclidean')
                            distance = min(distance, dist)
            return distance
        else:
            return search_ring(i, j,feature_index, ring_size + 1, truth)

def norm(index_y_a, index_x_a, index_y_b, index_x_b, type='euclidean'):
    if type == 'euclidean':
        return math.sqrt(abs(index_y_a - index_y_b)**2 + abs(index_x_a - index_x_b)**2)
    elif type == 'manhattan':
        return abs(index_y_a - index_y_b) + abs(index_x_a - index_x_b)


def main():
    distance_matrix()
if __name__ == '__main__':
    main()

My problem is replicating this in TensorFlow, since I need it for a custom loss function in Keras. How can I access the indices of the items I am iterating over?

I don't see any obstacle to using the distance transform in Keras. All you need is tf.py_func, which wraps an existing Python function as a TensorFlow operator.

However, I think the fundamental issue here is backpropagation. Your model won't have any problem in the forward pass, but what gradient do you expect to propagate? Or do you simply not care about the gradient at all?

I've done something similar with py_func to create a signed distance transform, using scipy. Here's what it might look like in your case:

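import numpy as np
import scipy.ndimage as ndi

def edt(mask):
    # scipy measures distance to the nearest zero and returns float64,
    # so invert the mask and cast to match the declared tf.float32.
    return ndi.distance_transform_edt(mask == 0).astype(np.float32)

arrs = [tf.py_func(edt, [tensor[..., i]], tf.float32) for i in range(C)]  # C = number of channels
edt_tf = tf.stack(arrs, axis=-1)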

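Note the limitations of py_func: the body of the wrapped Python function is not serialized into the GraphDef, so it will silently be missing from the models you save. See the tf.py_func documentation. In TensorFlow 2.x, tf.py_func is only available as tf.compat.v1.py_func; tf.numpy_function and tf.py_function are its replacements.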
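
If the py_func limitations are a concern, one alternative is a brute-force formulation built only from index grids, broadcasting, and a minimum reduction. The sketch below is written in NumPy so it can be checked on the CPU; it is not taken from the answers above, and the function name brute_force_edt is hypothetical. Each step has a direct TensorFlow counterpart (tf.meshgrid, tf.where, tf.reduce_min), so the same structure could be expressed with native ops. It assumes an (H, W, C) mask and small grids, because memory grows with (H*W)^2 * C.

```python
import numpy as np

def brute_force_edt(mask_3d):
    """Distance from every cell to the nearest 1 in the same channel.

    mask_3d has shape (H, W, C). A channel without any feature yields inf.
    """
    h, w, c = mask_3d.shape
    # Coordinates of every cell, flattened to shape (H*W, 2).
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=-1).astype(np.float64)
    # Pairwise squared distances between all cells, shape (H*W, H*W).
    diff = coords[:, None, :] - coords[None, :, :]
    sq_dist = (diff ** 2).sum(axis=-1)
    # Flattened masks, shape (H*W, C). Non-feature targets are set to inf
    # so that they can never win the minimum.
    is_feature = mask_3d.reshape(h * w, c) > 0
    candidates = np.where(is_feature[None, :, :], sq_dist[:, :, None], np.inf)
    # Minimum over all target cells, then back to (H, W, C).
    return np.sqrt(candidates.min(axis=1)).reshape(h, w, c)

# The 5x5 ground truth from the question's code.
feature_1 = np.eye(5)
feature_2 = np.zeros((5, 5))
feature_2[0, 1] = 1
feature_2[3, 0] = 1
ground_truth = np.stack((feature_1, feature_2), axis=2)

result = brute_force_edt(ground_truth)
print(np.round(result[..., 0], 2))
```

Because the coordinates come from an index grid rather than from a Python loop, this also sidesteps the question of how to access the indices of the cells being iterated over.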


 