简体   繁体   English

如何从keras API模型获取tf.gradients?

[英]How to get tf.gradients from keras API model?

I would like to know how to get tf.gradients from a model built using the Keras API. 我想知道如何从使用tf.gradients API构建的模型中获取tf.gradients

import Tensorflow as tf
from tensorflow import keras
from sklearn.datasets.samples_generator import make_blobs

# Create the model
inputs = keras.Input(shape=(2,))
x = keras.layers.Dense(12, activation='relu')(inputs)
x = keras.layers.Dense(8, activation='relu')(x)
predictions = keras.layers.Dense(3, activation='softmax')(x)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
          loss='categorical_crossentropy',
          metrics=['accuracy'])

# Generate random data
X, y = make_blobs(n_samples=1000, centers=3, n_features=2)
labels = keras.utils.to_categorical(y, num_classes=3)

# Compute the gradients wrt inputs
y_true = tf.convert_to_tensor(labels)
y_pred = tf.convert_to_tensor(np.round(model.predict(X)))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
grads = tf.gradients(model.loss_functions[0](y_true, y_pred), 
                      model.inputs[0])
sess.run(grads, input_dict={model.inputs[0]: X, model.outputs: y})

First attempt above: my grads are None . 上面的第一次尝试:我的毕业生为None With my second try below: 在下面的第二次尝试中:

sess.run(grads, input_dict={model.inputs: X, model.outputs: y })

I get the following error: 我收到以下错误:

TypeError: unhashable type: 'list'

I think you shouldn't create a new session directly with Tensorflow when using Keras. 我认为使用Keras时不应直接与Tensorflow创建新会话。 Instead, it is better to use the session implicitly created by Keras: 相反,最好使用Keras隐式创建的会话:

import keras.backend as K

sess = K.get_session()

However, I think in this case you don't need to retrieve the session at all. 但是,我认为在这种情况下,您根本不需要检索会话。 You can easily use backend functions, like K.gradients() and K.function() , to achieve your aim. 您可以轻松使用后端函数(例如K.gradients()K.function()来实现您的目标。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM