简体   繁体   English

如何获得 w.r.t 的权重梯度到目标神经元?

[英]How to get gradients of weights w.r.t. to a target neuron?

I have found code online to get the derivative of the total loss with respect to the deep learning weights.我在网上找到了代码来获得关于深度学习权重的总损失的导数。 I am trying to find the derivative of the weights with respect to the loss of a single class instead of all classes.我试图找到权重关于单个 class 而不是所有类的损失的导数。

I used the following code to get the gradient of an input image with respect to the total loss.我使用以下代码来获取输入图像相对于总损失的梯度。 If I visualize it, it shows the importance of the pixels for all predictions.如果我将其可视化,它会显示像素对所有预测的重要性。 But, I would like to compute the derivative of the input image with respect to a particular class (eg "lady_bug").但是,我想计算输入图像相对于特定 class(例如“lady_bug”)的导数。 This should show the importance of the pixels for the prediction of lady_bug.这应该显示像素对于预测 lady_bug 的重要性。 Do you have an idea how I can do that?你知道我该怎么做吗?

from keras.applications.vgg19 import VGG19
import numpy as np
import cv2
from keras import backend as K
import matplotlib.pyplot as plt

from keras.applications.inception_v3 import decode_predictions


def get_model():
    model = VGG19(include_top=True, weights='imagenet')
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


def predict(model, images):
    numeric_prediction = model.predict(images)
    categorical_prediction = decode_predictions(numeric_prediction, top=1)
    return [(x[0][1], x[0][2]) for x in categorical_prediction]


def get_test_image():
    # Image
    image_path = "lady_bug.jpg"
    image = cv2.imread(image_path)
    my_image = cv2.resize(image, (224,224))
    my_image = np.expand_dims(my_image, axis=0)
    return my_image


def visualize_sample(sample, file_path):
    plt.figure()
    plt.imshow(sample)
    plt.savefig(file_path, bbox_inches='tight')


def test_input_gradient():
    images = get_test_image()
    model = get_model()

    prediction = predict(model, images)
    print(prediction)

    gradients = K.gradients(model.output, model.input)              #Gradient of output wrt the input of the model (Tensor)
    print(gradients)

    sess = K.get_session()
    evaluated_gradients = sess.run(gradients[0], feed_dict={model.input:
    images})

    visualize_sample((evaluated_gradients[0]*(10**9.5)).clip(0,255), "test.png")


if __name__ == "__main__":
    test_input_gradient()

Output: Output:

[('ladybug', 0.53532666)]
[<tf.Tensor 'gradients/block1_conv1/convolution_grad/Conv2DBackpropInput:0' shape=(?, 224, 224, 3) dtype=float32>]

It seems the code is taking the gradients of the outputs wrt the inputs.似乎代码正在获取输入的输出梯度。
So, this is just taking a single slice from the outputs.所以,这只是从输出中取出一个切片。

Warning: This considers a regular model output.警告:这考虑了常规 model output。 I have no idea of what you're doing in decode predictions and the following list.我不知道您在解码预测和以下列表中在做什么。

gradients = K.gradients(model.output[:, lady_bug_class], model.input)   

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 Keras:输出wrt输入的梯度作为分类器的输入 - Keras : Gradients of output w.r.t. input as input to classifier 如何在Django中获取带有URL的http响应代码? - How to get http response code w.r.t. a url in django? 如何链接交互式问题(写 CodeJam)? - How to link interactive problems (w.r.t. CodeJam)? 如何在Tensorflow中获得损失wrt模型预测的梯度? - How can I get the gradient of the loss w.r.t. model prediction in Tensorflow? 如何在优化器中获得偏差和神经元权重? - How to get bias and neuron weights in optimizer? Tensorflow-如何使用模型参数获取输出的梯度 - Tensorflow - How to get the gradients of the output w.r.t the model parameters 如何使用 autograd.grad 计算 PyTorch 中的参数损失的 Hessian - How to compute Hessian of the loss w.r.t. the parameters in PyTorch using autograd.grad 如何维护长期存在的python项目w.r.t.依赖项和python版本? - How to maintain long-lived python projects w.r.t. dependencies and python versions? Python存储库组织WRT测试和PyDev - Python repository organization w.r.t. tests and PyDev 在scrapy中使用链接提取器时如何提取请求网址和响应网址? - How to extract request url w.r.t. response url when using link extractor in scrapy?
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM