
Keras lambda layer for l2 norm

I just want to implement a custom layer that takes the L2 norm of the difference between two vectors (of matching dimensions, of course) output by two different models in Keras. I'm using the functional API to write my Keras models, so I have code like:

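    from keras.layers import Input, Conv2D, Dense

    # First branch
    inp1 = Input(someshape)
    X = Conv2D(someargs)(inp1)
    ...  # further layers omitted
    out1 = Dense(128)(X)

    # Second branch
    inp2 = Input(someshape)
    Y = Conv2D(someargs)(inp2)
    ...  # further layers omitted
    out2 = Dense(128)(Y)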

Then I want to take the L2 norm of the difference between out1 and out2 (i.e., the Euclidean distance between them) and feed it into another network, so I have a Lambda layer like:

    l2dist = keras.layers.Lambda(l2dist)(out1,out2)

Where l2dist is the function defined as:

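    from keras import backend as K

    def l2dist(x, y):
        # K.sum with no axis sums over every axis, including the batch axis
        return K.sqrt(K.sum((x - y)**2))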

But the l2dist = ... line raises:

    TypeError: __call__() takes 2 positional arguments but 3 were given

I'm clearly passing only two arguments, out1 and out2, so why does Python think I'm giving it three?
I've also tried an inline lambda:

    l2dist = keras.layers.Lambda(lambda x,y: K.sqrt(K.sum((x-y)**2)))(out1,out2)

But I get the same error.
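The "3 were given" count includes self, which Python passes implicitly to any method. In the Keras 2.x versions this question appears to use, Layer.__call__ is declared as __call__(self, inputs, **kwargs), i.e. one explicit input. The sketch below reproduces this with a plain Python class; SingleInputLayer is a hypothetical stand-in, not a Keras class.

```python
class SingleInputLayer:
    """Hypothetical stand-in for a layer whose __call__ takes one input."""

    def __call__(self, inputs):
        # Python passes `self` implicitly, so this method has two
        # positional parameters: self and inputs.
        return inputs


layer = SingleInputLayer()

try:
    layer("out1", "out2")  # self + two tensors = 3 positional arguments
except TypeError as err:
    message = str(err)
    print(message)  # "...__call__() takes 2 positional arguments but 3 were given"

# Packing both tensors into a single list leaves one explicit argument.
print(layer(["out1", "out2"]))  # ['out1', 'out2']
```

The same applies to the inline lambda: the error comes from calling the layer with two arguments, not from how the wrapped function is defined.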

I discovered that a Keras Lambda layer accepts only one input argument, so the function has to operate on a list and the two tensors have to be passed in as a list.

I also realized that I can't use the L2 norm as written, since it only gives me one number to run the final layers on. I have to use a distance function that gives an element-wise distance rather than a single Euclidean distance between the two vectors. I'm now using the chi-squared distance, so my code looks like this, and it runs (although the loss comes out as nan, but that's a different issue, I guess; at least it runs):

    chisqdist = keras.layers.Lambda(lambda x: (x[0]-x[1])**2/(x[0]+x[1]))([out1,out2])
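A minimal sketch of two list-input distance functions, written with NumPy so the shapes can be checked without building a model. This is an assumption for illustration: inside a Lambda layer the same expressions would use the Keras backend (the docstrings give the K equivalents). l2_distance shows that reducing over axis=-1 with keepdims=True yields one distance per sample, shape (batch, 1), rather than a single scalar for the whole batch. chi_squared_elementwise adds a small EPS to the denominator, one common guard against the 0/0 that turns (x0 - x1)**2 / (x0 + x1) into nan. It assumes both inputs are non-negative (e.g. after a ReLU or softmax), as the chi-squared distance normally does. The function names, EPS, and the toy arrays are hypothetical.

```python
import numpy as np

# Assumption: mirrors the default of keras.backend.epsilon() (1e-7).
EPS = 1e-7


def l2_distance(tensors):
    """Per-sample Euclidean distance between two (batch, dim) arrays.

    Takes one list argument, the form a Lambda layer passes to its function.
    Keras-backend equivalent inside a Lambda (assuming K = keras.backend):
        K.sqrt(K.sum(K.square(x - y), axis=-1, keepdims=True))
    """
    x, y = tensors
    # axis=-1 reduces over features only, so each sample keeps its own
    # distance; keepdims=True gives shape (batch, 1) for a following Dense.
    return np.sqrt(np.sum(np.square(x - y), axis=-1, keepdims=True))


def chi_squared_elementwise(tensors):
    """Element-wise chi-squared term (x - y)**2 / (x + y), guarded by EPS.

    Keras-backend equivalent inside a Lambda (assuming K = keras.backend):
        K.square(x - y) / (x + y + K.epsilon())
    """
    x, y = tensors
    # Without EPS, positions where x and y are both 0 give 0/0 = nan.
    # The guard only keeps the denominator positive if x, y >= 0.
    return np.square(x - y) / (x + y + EPS)


# Toy batch of 2 samples with 3 features each (hypothetical values).
out1 = np.array([[1.0, 2.0, 0.0], [0.0, 0.0, 0.0]])
out2 = np.array([[1.0, 0.0, 0.0], [3.0, 4.0, 0.0]])

print(l2_distance([out1, out2]))  # per-sample distances 2.0 and 5.0, shape (2, 1)
print(chi_squared_elementwise([out1, out2]))  # all finite, shape (2, 3)
```

With these toy inputs the unguarded expression would return nan in the last position of each row, where both inputs are 0. Since Dense(128) without an activation can also output negative values, a near-zero or negative x0 + x1 is one plausible source of the nan loss; this is a hypothesis, not something the post confirms.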
