
How to use a pre-trained TensorFlow net in Keras loss function

I have a pre-trained net that I want to use to evaluate the loss in my Keras model. The pre-trained network was trained with TensorFlow, and I just want to use it as part of my loss calculation.

The code of my custom loss function is currently:

from keras import backend as K

def custom_loss_func(y_true, y_pred):
    # Get saliency of both true and pred
    sal_true = deep_gaze.get_saliency_map(y_true)
    sal_pred = deep_gaze.get_saliency_map(y_pred)

    return K.mean(K.square(sal_true - sal_pred))

Where deep_gaze is an object that is meant to manage access to the external pre-trained net I am using.

It is defined this way:

import os
import tensorflow as tf


class DeepGaze(object):
    CHECK_POINT = os.path.join(os.path.dirname(__file__), 'DeepGazeII.ckpt')  # DeepGaze II

    def __init__(self):
        print('Loading Deep Gaze II...')

        with tf.Graph().as_default() as deep_gaze_graph:
            saver = tf.train.import_meta_graph('{}.meta'.format(self.CHECK_POINT))

            self.input_tensor = tf.get_collection('input_tensor')[0]
            self.log_density_wo_centerbias = tf.get_collection('log_density_wo_centerbias')[0]

        self.tf_session = tf.Session(graph=deep_gaze_graph)
        saver.restore(self.tf_session, self.CHECK_POINT)

        print('Deep Gaze II Loaded')

    def get_saliency_map(self, input_data):
        '''
        Returns the saliency map of the input data.
        The input format is a 4D array [batch_num, height, width, channel].
        '''
        log_density_prediction = self.tf_session.run(self.log_density_wo_centerbias,
                                                     {self.input_tensor: input_data})

        return log_density_prediction
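
For reference, calling this method with a concrete numpy batch (rather than a symbolic Keras tensor) is the intended kind of usage. The call below is only a hypothetical illustration with made-up shapes:

import numpy as np

deep_gaze = DeepGaze()
# Hypothetical batch of two 224x224 RGB images; any concrete ndarray with
# shape [batch_num, height, width, channel] is an acceptable feed value.
dummy_batch = np.random.rand(2, 224, 224, 3).astype(np.float32)
saliency_maps = deep_gaze.get_saliency_map(dummy_batch)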

When I run training with this custom loss function, I get the error:

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.

What am I doing wrong? Is there a way to evaluate a net on a TensorFlow object coming from a different net (one that was made with Keras on the TensorFlow backend)?

Thanks in advance.

There are two main problems:

  • When you call get_saliency_map with input_data=y_true, you are feeding a tensor (input_data) to another tensor (self.input_tensor), which is not valid: feed values must be concrete values such as numpy arrays, not symbolic tensors (see the standalone example after this list). Moreover, at graph-construction time these tensors do not hold a value; they only define a computation that will eventually produce one.

  • Even if you could get an output from get_saliency_map, your code would still not work, because this function disconnects your TensorFlow graph: it returns a numpy array rather than a tensor, and all of the loss logic must live inside the graph, with each tensor computed from the other tensors available there.
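
The distinction in the first point is easy to reproduce in isolation. In the standalone TF 1.x sketch below (unrelated to DeepGaze), feeding a numpy array works, while feeding another tensor raises exactly the TypeError above:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])
y = x * 2.0

with tf.Session() as sess:
    # Valid: feed values are concrete numpy arrays (or Python scalars, lists, ...).
    print(sess.run(y, feed_dict={x: np.ones((2, 3), dtype=np.float32)}))

    # Invalid: a symbolic tensor holds no value at this point, so it cannot be fed.
    # sess.run(y, feed_dict={x: y})  # TypeError: The value of a feed cannot be a tf.Tensor object.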

The solution is to define the model that produces self.log_density_wo_centerbias inside the same graph where you define your loss function, using the tensors y_true and y_pred directly as inputs, without disconnecting the graph.
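
A minimal sketch of that idea, assuming TF 1.x with the Keras TensorFlow backend: re-import the DeepGaze II meta-graph into the graph Keras is building, remapping its input placeholder onto y_true / y_pred via input_map, so the whole loss stays symbolic. The placeholder name 'input_tensor:0' below is a hypothetical stand-in (the real name would have to be looked up, e.g. from the 'input_tensor' collection), and the restore step may need adjusting for your checkpoint:

import tensorflow as tf
from keras import backend as K

CHECK_POINT = 'DeepGazeII.ckpt'  # same checkpoint as above


def saliency_of(tensor, scope):
    # Import the pre-trained graph into the current (Keras) graph, wiring its
    # input placeholder directly to `tensor` instead of feeding values at run time.
    saver = tf.train.import_meta_graph(
        '{}.meta'.format(CHECK_POINT),
        input_map={'input_tensor:0': tensor},  # hypothetical placeholder name
        import_scope=scope)
    # Load the pre-trained weights into the session Keras is using.
    saver.restore(K.get_session(), CHECK_POINT)
    # The collection now also contains the output tensor from this import; take the latest entry.
    return tf.get_collection('log_density_wo_centerbias')[-1]


def custom_loss_func(y_true, y_pred):
    sal_true = saliency_of(y_true, 'deep_gaze_true')
    sal_pred = saliency_of(y_pred, 'deep_gaze_pred')
    return K.mean(K.square(sal_true - sal_pred))

With this, custom_loss_func can be passed to model.compile(loss=custom_loss_func) as usual. Note that the pre-trained graph is imported twice (once per scope); a single import over tf.concat([y_true, y_pred], axis=0) followed by tf.split would be a more memory-friendly variant of the same idea.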
