
Keras: Gradients of output w.r.t. input as input to classifier

I am doing research, and for an experiment I want to use the gradients of a specific layer in a network with respect to the network's input (similar to guided backprop) as the input to another network (a classifier). The goal is to "force" the network to change its "attention" according to the classifier, so the two networks should be trained simultaneously.

I implemented it this way:

import keras
import keras.backend as K
from keras.models import Model

input_tensor = model.input
output_tensor = model.layers[-2].output
# gradient of the penultimate layer's output w.r.t. the network input
grad_calc = keras.layers.Lambda(lambda x: K.gradients(x, input_tensor)[0],
                                output_shape=(256, 256, 3), trainable=False)(output_tensor)
pred = classifier(grad_calc)
out_model = Model(input_tensor, pred)

out_model.compile(loss='mse', optimizer=keras.optimizers.Adam(0.0001), metrics=['accuracy'])

Then, when I try to train the model

out_model.train_on_batch(imgs, np.zeros((imgs.shape[0], 2)))

it does not work. It seems to get stuck: nothing happens (no error or any other message).

I am not sure whether this is the right way to implement it, so I would be very thankful if someone with more experience could take a look and give me advice.
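One way to narrow the problem down is to evaluate only the gradient branch, without the classifier or the training step, and see whether the gradient computation itself hangs. A minimal sketch, assuming the model, input_tensor, grad_calc, and imgs defined above and a TensorFlow backend:

# run just the gradient Lambda output on a batch of images
get_grads = K.function([input_tensor], [grad_calc])
grads = get_grads([imgs])[0]   # expected shape: (batch, 256, 256, 3)
print(grads.shape, grads.min(), grads.max())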

If I were trying to achieve this, I would switch to plain TensorFlow and do something along these lines:

import tensorflow as tf

# build model
input = tf.placeholder(tf.float32, [None, 256, 256, 3])
net   = tf.layers.conv2d(input, 12, 3)   # 12 filters, 3x3 kernel
loss  = tf.nn.l2_loss(net)
step  = tf.train.AdamOptimizer().minimize(loss)

# now inspect your graph and select the gradient tensor you are looking for
for op in tf.get_default_graph().get_operations():
    print(op.name)
grad = tf.get_default_graph().get_operation_by_name("enqueue")  # placeholder name

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, grad_val, input_val = sess.run([step, grad, input], ...)
    # feed your grad and input into another network
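Alternatively, instead of searching the graph by op name, the gradient tensor can be built explicitly with tf.gradients and wired straight into a second network inside the same graph, so both parts update together. A minimal sketch under that assumption; the conv layer, the 2-unit classifier head, and the joint objective are hypothetical placeholders, not the asker's actual architecture:

import tensorflow as tf

# base network (hypothetical stand-in for the real model)
input_ph  = tf.placeholder(tf.float32, [None, 256, 256, 3])
feat      = tf.layers.conv2d(input_ph, 12, 3)        # the "specific layer" output
base_loss = tf.nn.l2_loss(feat)

# gradient of the layer output w.r.t. the network input, built explicitly
grad = tf.gradients(feat, input_ph)[0]               # shape (None, 256, 256, 3)

# classifier that consumes the gradient map (hypothetical head)
logits   = tf.layers.dense(tf.layers.flatten(grad), 2)
labels   = tf.placeholder(tf.float32, [None, 2])
cls_loss = tf.losses.mean_squared_error(labels, logits)

# one optimizer over a joint objective, so both networks are trained simultaneously
step = tf.train.AdamOptimizer(1e-4).minimize(base_loss + cls_loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # sess.run(step, feed_dict={input_ph: imgs, labels: targets})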
