
Keras on-the-fly sample weights calculation in a custom loss function

I'm implementing a type of segmentation network in Keras (with the TensorFlow backend) where I want to weight the loss for each image. The weight will have the same shape as the target image, i.e. each pixel will be weighted differently in the loss. As far as I can tell, Keras has no way to implement this weighting scheme natively. The weights can be calculated from the labels, and storing the weights on disk is infeasible due to the size of the dataset. As such, I have started writing my own loss function that calculates the weight matrix on the fly from the labels.

labels.shape == (?, 5, 101, 101). The weight can be looked up from a dictionary based on the values along axis 1. For example, if labels[0, :, 0, 0] is [1, 0, 0, 1, 0] for the first item in the batch, I would look up that pattern in a dictionary. If labels were a numpy array, I could do the following:

labels.astype(int).astype(str).astype(object).sum(axis=1)

This would give me an array of shape (?, 101, 101) with entries like '10010', which I could look up in a dictionary called lookup. I could then assign weights[0, 0, 0] = lookup['10010']. Finally, with the weights tensor populated, I could get my loss as:

keras.backend.categorical_crossentropy(labels, predictions) * keras.backend.constant(weights)
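For reference, the NumPy version of the lookup described above could be sketched as follows. This is only an illustration: the contents of lookup, the helper name weights_from_labels and the default weight of 1.0 for unknown patterns are assumptions, not part of the original setup. The same function could also be called from a data generator if the weights are computed there instead.

```python
import numpy as np

# Hypothetical lookup table: keys are the label values along axis 1,
# joined into a string such as '10010'.
lookup = {"10010": 2.0, "01000": 0.5}


def weights_from_labels(labels, lookup, default=1.0):
    """Build a (batch, H, W) weight array from (batch, C, H, W) labels."""
    ints = np.asarray(labels).astype(int)
    # Move the class axis last so each row holds one pixel's pattern.
    moved = np.moveaxis(ints, 1, -1)              # (batch, H, W, C)
    flat = moved.reshape(-1, moved.shape[-1])     # (batch * H * W, C)
    # Join each pattern into a string key and look up its weight;
    # patterns missing from the dictionary get the (assumed) default.
    flat_weights = np.array(
        [lookup.get("".join(str(v) for v in row.tolist()), default)
         for row in flat],
        dtype=float,
    )
    return flat_weights.reshape(moved.shape[:-1])  # (batch, H, W)
```

The Python-level loop is slow for large batches, so this is mainly useful for checking a symbolic implementation against known values.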

The problem is that I can't call keras.backend.eval(labels) to get the numpy array in my custom loss function. While compiling the model, the graph is constructed without any data being fed. Using eval at this point causes an InvalidArgumentError: You must feed a value for placeholder tensor error. Is there a way to do the dtype conversion to string, the dictionary lookup and the assignment of the weights tensor symbolically using keras operations?

Or is there a way I could work around this issue and calculate the weights elsewhere in the code? I'm using a keras.utils.Sequence and calling fit_generator on the model for training, so it is also possible to return the weights from the data generator.

I'm using: tensorflow 1.9.0 and keras 2.2.4

I'm happy to switch to more up-to-date versions of the packages.

I'm going to answer my own question here in case someone else has the same problem.

tensorflow 2 has some functions that are quite useful for this problem:

  1. For converting a matrix of floats to strings - tf.strings.as_string
  2. For joining strings across a dimension - tf.strings.reduce_join
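  3. For looking up values of a tensor in a dictionary called weight_dict - tf.lookup.StaticHashTable

import tensorflow as tf

# labels: tensor of shape (batch, 5, 101, 101) holding 0/1 values.
# Cast to int so that 1.0 is rendered as "1", then join along axis 1
# to get strings like "10010" with shape (batch, 101, 101).
tensor_of_interest = tf.strings.reduce_join(
    tf.strings.as_string(tf.cast(labels, tf.int32)), axis=1
)

# weight_dict maps strings like "10010" to numeric weights.
keys = tf.constant(list(weight_dict.keys()))
values = tf.constant(list(weight_dict.values()))
weight_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1
)
weights = weight_table.lookup(tensor_of_interest)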
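
Putting the three pieces together, a weighted loss might look roughly like the sketch below. Several things here are assumptions made for illustration: the contents of weight_dict, the function name weighted_loss, the default weight of 1.0 for patterns missing from the dictionary, and the hand-written cross-entropy over axis 1 (used in place of keras.backend.categorical_crossentropy so that the class axis is explicit).

```python
import tensorflow as tf

# Hypothetical weights keyed by the joined label pattern along axis 1.
weight_dict = {"10010": 2.0, "01000": 0.5}

keys = tf.constant(list(weight_dict.keys()))
values = tf.constant(list(weight_dict.values()), dtype=tf.float32)
weight_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=1.0,  # assumed weight for patterns not in weight_dict
)


def weighted_loss(labels, predictions):
    """Per-pixel weighted cross-entropy for (batch, C, H, W) tensors."""
    # Turn each pixel's 0/1 values along axis 1 into a key like "10010".
    codes = tf.strings.reduce_join(
        tf.strings.as_string(tf.cast(labels, tf.int32)), axis=1
    )                                             # (batch, H, W)
    weights = weight_table.lookup(codes)          # (batch, H, W)

    # Manual cross-entropy over the class axis (axis 1); predictions are
    # assumed to be probabilities and are clipped to avoid log(0).
    labels = tf.cast(labels, tf.float32)
    log_p = tf.math.log(tf.clip_by_value(predictions, 1e-7, 1.0))
    per_pixel = -tf.reduce_sum(labels * log_p, axis=1)  # (batch, H, W)

    return per_pixel * weights
```

The function returns one loss value per pixel; it could be passed as loss=weighted_loss to model.compile, with any reduction to a scalar left to Keras or added explicitly with tf.reduce_mean.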

Happy coding
