Keras model with tf.contrib.losses.metric_learning.triplet_semihard_loss: AssertionError

I am using Python 3 with Anaconda and trying to use a tf.contrib loss function with a Keras model.

Here is the code:

from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.contrib.losses import metric_learning

input_shape = (28, 28)  # hypothetical; the original snippet does not define input_shape

model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(50, activation="relu"))
# Passing the contrib loss directly is what triggers the AssertionError below
model.compile(loss=metric_learning.triplet_semihard_loss, optimizer=Adam())

I get the following error:

File "/home/user/.local/lib/python3.6/site-packages/keras/engine/training_utils.py", line 404, in weighted
    score_array = fn(y_true, y_pred)
File "/home/user/anaconda3/envs/siamese/lib/python3.6/site-packages/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py", line 179, in triplet_semihard_loss
    assert lshape.shape == 1
AssertionError

The same network works fine with a Keras loss function. I also tried wrapping the TF loss function in a function like so:

import tensorflow as tf

def func(y_true, y_pred):
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(y_true, y_pred)

I still get the same error.

What am I doing wrong here?

Update: when I change func to return the following,

return K.categorical_crossentropy(y_true, y_pred)

everything works fine! But I can't get it to work with this particular TF loss function.

When I go into tf.contrib.losses.metric_learning.triplet_semihard_loss and remove the line assert lshape.shape == 1, it runs fine.

Thanks

The problem is that you are passing the wrong input to the loss function.

According to the triplet_semihard_loss docstring, you need to pass labels and embeddings.

So your code should be:

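import tensorflow as tf
def func(y, embeddings):
    labels = tf.reshape(tf.cast(y, tf.int32), [-1])  # 1-D int32 labels, shape [batch_size]
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=labels, embeddings=embeddings)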

Two more notes about the embedding network:

  1. The last Dense layer should have no activation.

  2. Don't forget to L2-normalize the output vector: model.add(Lambda(lambda x: K.l2_normalize(x, axis=1)))
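The normalization in note 2 can be illustrated without Keras. Below is a minimal plain-Python sketch, assuming embeddings are given as a list of rows; l2_normalize_rows and sq_dist are hypothetical helpers, and the epsilon guard follows the formula TensorFlow documents for tf.math.l2_normalize, x / sqrt(max(sum(x**2), epsilon)). It also suggests one likely reason for note 1: with a ReLU on the last layer every coordinate is non-negative, so two unit-length embeddings have a dot product of at least 0 and a squared distance of at most 2, instead of the full range up to 4.

```python
import math


def l2_normalize_rows(x, epsilon=1e-12):
    """Return a copy of the 2-D list x with every row scaled to unit L2 norm.

    Divides by sqrt(max(sum of squares, epsilon)), so an all-zero row stays
    zero instead of causing a division by zero.
    """
    out = []
    for row in x:
        norm = math.sqrt(max(sum(v * v for v in row), epsilon))
        out.append([v / norm for v in row])
    return out


def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


# Hypothetical embedding batch without a final activation: rows may point
# in opposite directions, so unit vectors can be up to 4 apart (squared).
free = l2_normalize_rows([[3.0, 4.0], [-3.0, -4.0]])

# ReLU-like batch: all coordinates >= 0, so squared distance is at most 2.
relu_like = l2_normalize_rows([[1.0, 0.0], [0.0, 2.0]])

print(free[0], sq_dist(free[0], free[1]))
print(relu_like, sq_dist(relu_like[0], relu_like[1]))
```

In the Keras model itself, the normalization is the Lambda layer from note 2, placed after a Dense layer without an activation.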

It seems that your problem comes from incorrect input to the loss function. The triplet loss expects the following parameters:

Args:
labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
  multiclass integer labels.
embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
  be l2 normalized.
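To make the roles of labels and embeddings concrete, here is a plain-Python sketch of the semi-hard selection rule as we understand it from the contrib implementation: for each anchor/positive pair, take the closest negative that is still farther from the anchor than the positive; if no such negative exists, fall back to the farthest negative; then average max(margin + d_ap - d_an, 0) over all anchor/positive pairs, using squared Euclidean distances and margin=1.0. This is an illustrative re-implementation for small inputs, not the TensorFlow code; triplet_semihard_loss_ref and pairwise_sq_dist are hypothetical names.

```python
def pairwise_sq_dist(emb):
    """Squared Euclidean distance between every pair of rows of emb."""
    n = len(emb)
    return [[sum((a - b) ** 2 for a, b in zip(emb[i], emb[j])) for j in range(n)]
            for i in range(n)]


def triplet_semihard_loss_ref(labels, embeddings, margin=1.0):
    """Sketch of semi-hard triplet loss on plain lists.

    labels:     1-D list of ints, one per row (shape [batch_size])
    embeddings: 2-D list, batch_size x dim, assumed L2-normalized
    Assumes at least two classes and at least one anchor/positive pair.
    """
    if not labels or isinstance(labels[0], (list, tuple)):
        raise ValueError("labels must be a non-empty 1-D sequence")
    d = pairwise_sq_dist(embeddings)
    n = len(labels)
    total, num_pairs = 0.0, 0
    for a in range(n):
        negatives = [d[a][k] for k in range(n) if labels[k] != labels[a]]
        if not negatives:
            raise ValueError("every anchor needs at least one negative")
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # only anchor/positive pairs contribute
            d_ap = d[a][p]
            # Semi-hard: negatives farther away than the positive.
            farther = [x for x in negatives if x > d_ap]
            d_an = min(farther) if farther else max(negatives)
            total += max(margin + d_ap - d_an, 0.0)
            num_pairs += 1
    if num_pairs == 0:
        raise ValueError("need at least one anchor/positive pair")
    return total / num_pairs


print(triplet_semihard_loss_ref([0, 0, 1, 1], [[1, 0], [0, 1], [1, 0], [-1, 0]]))
```

With labels [0, 0, 1, 1] and the unit vectors [[1, 0], [0, 1], [1, 0], [-1, 0]], the four anchor/positive pairs contribute 0, 1, 3 and 1, so the sketch returns 1.25. Passing 2-D labels such as [[0], [0], [1], [1]] is rejected, which mirrors the shape requirement in the docstring.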

Are you sure that y_true has the correct shape? Can you give us more details about the tensors you are using?
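On the shape question: the line that fails, assert lshape.shape == 1, is a static rank check. In the contrib source, lshape is the shape of labels (a 1-D tensor with one entry per dimension), so the check passes only when labels has rank 1, i.e. shape [batch_size]. Keras typically builds y_true with the same rank as the model output (2-D here), which would trip it. The sketch below mimics the check on nested Python lists; shape_of and labels_pass_rank_check are hypothetical helpers, not TensorFlow functions.

```python
def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[1], [2]] -> (2, 1)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def labels_pass_rank_check(labels):
    """Mimic `assert lshape.shape == 1`: the shape must have exactly one entry."""
    return len(shape_of(labels)) == 1


# y_true as Keras may hand it over: one label per row, but 2-D (batch_size, 1)
y_true = [[0], [0], [1], [1]]
print(shape_of(y_true), labels_pass_rank_check(y_true))  # (4, 1) False

# Flattened to 1-D, shape (batch_size,), as the docstring asks for
labels = [row[0] for row in y_true]
print(shape_of(labels), labels_pass_rank_check(labels))  # (4,) True
```

In TensorFlow 1.x the same flattening is usually done with tf.reshape(y_true, [-1]), combined with tf.cast(..., tf.int32) to match the documented label dtype.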
