
Pass a tensor returned from keras.backend.argmax as indices to keras.backend.gather, which expects 'An integer tensor of indices.'

I am trying to implement a custom loss function:

def lossFunction(self,y_true,y_pred):
    maxi=K.argmax(y_true)
    return K.mean((K.max(y_true) -(K.gather(y_pred,maxi)))**2)

which gives the following error when training:


InvalidArgumentError (see above for traceback): indices[5] = 51 is not in [0, 32) [[Node: loss/dense_3_loss/Gather = Gather[Tindices=DT_INT64, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dense_3/BiasAdd, metrics/acc/ArgMax)]]


Model summary:


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 64, 50, 1)     0                                            
____________________________________________________________________________________________________
input_2 (InputLayer)             (None, 64, 50, 1)     0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 32, 25, 16)    272         input_1[0][0]                    
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (None, 32, 25, 16)    272         input_2[0][0]                    
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)   (None, 16, 12, 16)    0           conv2d_1[0][0]                   
____________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)   (None, 16, 12, 16)    0           conv2d_2[0][0]                   
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (None, 15, 11, 32)    2080        max_pooling2d_1[0][0]            
____________________________________________________________________________________________________
conv2d_4 (Conv2D)                (None, 15, 11, 32)    2080        max_pooling2d_2[0][0]            
____________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)   (None, 8, 6, 32)      0           conv2d_3[0][0]                   
____________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)   (None, 8, 6, 32)      0           conv2d_4[0][0]                   
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1536)          0           max_pooling2d_3[0][0]            
____________________________________________________________________________________________________
flatten_2 (Flatten)              (None, 1536)          0           max_pooling2d_4[0][0]            
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 3072)          0           flatten_1[0][0]                  
                                                                   flatten_2[0][0]                  
____________________________________________________________________________________________________
input_3 (InputLayer)             (None, 256)           0                                            
____________________________________________________________________________________________________
concatenate_2 (Concatenate)      (None, 3328)          0           concatenate_1[0][0]              
                                                                   input_3[0][0]                    
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           1704448     concatenate_2[0][0]              
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 256)           131328      dense_1[0][0]                    
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 256)           65792       dense_2[0][0]                    
====================================================================================================
Total params: 1,906,272
Trainable params: 1,906,272
Non-trainable params: 0

Argmax is taking from the last axis, while gather is taking from the first. You don't have the same number of elements in both axes, so this error is expected.
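To see why an index like 51 falls outside [0, 32), consider the shapes involved. This is a minimal sketch in plain NumPy; the batch size 32 and the 256 output units are taken from the traceback and the model summary, and the variable names are my own:

import numpy as np

y_true = np.zeros((32, 256), dtype="float32")  # (batch, classes), like dense_3's output
y_true[5, 51] = 1.0                            # sample 5 has its maximum at class 51
maxi = y_true.argmax(axis=-1)                  # shape (32,), values in [0, 256)
# K.gather(y_pred, maxi) indexes the FIRST axis of y_pred, whose length is the
# batch size 32, so any class index >= 32 (here 51) is out of range.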

To work only on the classes, use the last axis, so we are going to work around the gather method:

def lossFunction(self,y_true,y_pred):
    maxi = K.argmax(y_true)  # ok: indices along the last (class) axis

    # invert the axes so that gather picks from the class axis
    y_pred = K.permute_dimensions(y_pred, (1, 0))

    return K.mean((K.max(y_true, axis=-1) - K.gather(y_pred, maxi)) ** 2)
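For reference, an equivalent way to pick the prediction at the position of the true maximum, without permuting, is to mask with a one-hot vector. This is a sketch of my own (it assumes 2-D (batch, classes) tensors and is not part of the answer above):

from keras import backend as K

def lossFunctionOneHot(self, y_true, y_pred):
    # one-hot of the position of y_true's maximum, shape (batch, classes)
    mask = K.one_hot(K.argmax(y_true, axis=-1), K.int_shape(y_pred)[-1])
    picked = K.sum(y_pred * mask, axis=-1)  # y_pred at that position, shape (batch,)
    return K.mean((K.max(y_true, axis=-1) - picked) ** 2)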
