繁体   English   中英

自定义损失函数,它依赖于keras中的另一个神经网络

[英]Custom loss function which depends on another neural network in keras

我对keras有一个“我该怎么做”的问题:

假设我有第一个神经网络,比如说NNa,它有4个已经训练过的输入(x,y,z,t)。 如果我有第二个神经网络,例如NNb,那么它的损失函数取决于第一个神经网络。

NNb的自定义损失函数customLossNNb调用具有固定网格(x,y,z)的NNa的预测,只需修改最后一个变量t。

在这里,我要用伪python代码训练第二个NN:NNb:

grid=np.mgrid[0:10:1,0:10:1,0:10:1].reshape(3,-1).T

Y[:,0]=time
Y[:,1]=something

def customLossNNb(NNa,grid):
     def diff(y_true,y_pred): 
         for ii in range(y_true.shape[0]):
               currentInput=concatenation of grid and y_true[ii,0]
               toto[ii,:]=NNa.predict(currentInput)
               #some stuff with toto
         return #...
     return diff

然后

NNb.compile(loss=customLossNNb(NNa,K.variable(grid)),optimizer='Adam')
NNb.fit(input,Y)

实际上,引起我麻烦的那一行是currentInput=concatenation of grid and y_true[ii,0]

我试图使用K.variable(grid)将网格作为张量发送到customLossNNb。 但是我无法在损失函数中定义新的张量,例如CurrentY ,其形状为(grid.shape[0],1)填充y[ii,0]当前t),然后连接gridcurrentY建立currentInput

有任何想法吗?

谢谢

您可以使用keras的功能API将自定义损失函数包括在图中。 在这种情况下,该模型可以用作函数,如下所示:

for l in NNa.layers: 
    l.trainable=False
x=Input(size)
y=NNb(x)
z=NNa(y)

预测方法将不起作用,因为损失函数应该是图形的一部分,并且预测方法返回np.array

首先,使NNa不可训练。 请注意,如果您的模型具有内部模型,则应递归执行此操作。

def makeUntrainable(layer):
    layer.trainable = False

    if hasattr(layer, 'layers'):
        for l in layer.layers:
            makeUntrainable(l)

makeUntrainable(NNa)

然后,您有两个选择:

  • 将NNa附加到模型的末尾(注意y_truey_pred都将被更改)
    • 然后更改目标(使用NNa进行预测)以获得正确的结果,因为您的模型现在期望使用NNa而不是NNb的输出。
  • 创建一个在其中使用NNa的自定义损失函数,而无需更改目标

选项1-附加模型

inputs = NNb.inputs   
outputs = NNa(NNb.outputs) #make sure NNb is outputing 4 tensors to match NNa inputs   
fullModel = Model(inputs,outputs)

#changing the targets:
newY_train = NNa.predict(oldY_train)    

选项2-创建自定义损失

警告:在训练此配置时,请测试NNa的重量是否真的冻结了

from keras.losses import binary_crossentropy

def customLoss(true,pred):
    true = NNa(true)
    pred = NNa(pred)

    #use some of the usual losses or create your own
    binary_crossentropy(true,pred)

NNb.compile(optimizer=anything, loss = customLoss)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM