Calculating the derivatives of the output with respect to the input for a given time step in an LSTM (TensorFlow 2.0)
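I wrote some sample code to reproduce the real problem I am facing in my project. I am using an LSTM in TensorFlow to model some time series data. The input has shape (10, 100, 1), i.e. 10 instances, 100 time steps, and 1 feature. The output has the same shape.
After training the model, I want to study the influence of each input on each output at each specific time step. In other words, for every time step I want to see which inputs affect my output the most (that is, which inputs have the largest gradients). Here is the code for this problem: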
import numpy as np
import tensorflow as tf

tf.keras.backend.clear_session()
tf.random.set_seed(42)

# Random inputs and targets: 10 instances, 100 time steps, 1 feature, served as a single batch
model_input = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_input = model_input.batch(10)
model_output = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_output = model_output.batch(10)
my_dataset = tf.data.Dataset.zip((model_input, model_output))

# LSTM that returns the full sequence, followed by a per-time-step Dense layer
m_inputs = tf.keras.Input(shape=(None, 1))
lstm_outputs = tf.keras.layers.LSTM(32, return_sequences=True)(m_inputs)
m_outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(lstm_outputs)
my_model = tf.keras.Model(m_inputs, m_outputs, name="my_model")

my_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
my_loss_fn = tf.keras.losses.MeanSquaredError()
my_epochs = 3
for epoch in range(my_epochs):
    for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
        # Open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
        with tf.GradientTape() as tape:
            # Run the forward pass of the model
            logits = my_model(x_batch_tr, training=True)
            # Compute the loss value for this minibatch
            loss_value = my_loss_fn(y_batch_tr, logits)
        # Use the gradient tape to retrieve the gradients of the loss with respect to the trainable variables
        grads = tape.gradient(loss_value, my_model.trainable_weights)
        # Run one step of gradient descent by updating the variables to minimize the loss
        my_optimizer.apply_gradients(zip(grads, my_model.trainable_weights))
        print(f"Step {step}, loss: {loss_value}")
print("\n\nCalculate gradient of outputs w.r.t. inputs\n\n")
for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
    # Open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
    with tf.GradientTape() as tape:
        tape.watch(x_batch_tr)
        # Run the forward pass of the model
        logits = my_model(x_batch_tr, training=True)
        # tape.watch(logits[:, 10, :])  # this didn't help
        # Compute the loss value for this minibatch
        loss_value = my_loss_fn(y_batch_tr, logits)
    # Use the gradient tape to retrieve the gradients of the outputs with respect to the inputs
    # grads = tape.gradient(logits, x_batch_tr)  # This works
    # print(grads.numpy().shape)                 # This works
    grads = tape.gradient(logits[:, 10, :], x_batch_tr)  # This returns None
    print(grads)
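In other words, I want to focus on the inputs that affect my output the most (at each specific time step).
For me, grads = tape.gradient(logits, x_batch_tr) does not do the job, because it adds up the gradients of all outputs with respect to each input.
However, when I take the gradient of a single output time step as in the code above, the result is always None.
Any help is much appreciated!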
You can use tf.GradientTape.batch_jacobian to get exactly that information:
grads = tape.batch_jacobian(logits, x_batch_tr)
print(grads.shape)
# (10, 100, 1, 100, 1)
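Here, grads[i, t1, f1, t2, f2] gives you, for example i, the gradient of output feature f1 at time t1 with respect to input feature f2 at time t2. If, as in your case, you only have one feature, you can say that grads[i, t1, 0, t2, 0] gives you the gradient of the output at t1 with respect to the input at t2. Conveniently, you can also aggregate over different axes or slices of this result to get aggregated gradients. For example, tf.reduce_sum(grads[:, :, :, :10], axis=3) gives you, for each output time step, the sum of its gradients with respect to the first ten input time steps.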
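The indexing described above can be sketched as follows. This is only an illustration, not the answerer's code: it rebuilds an untrained, smaller version of the question's model (3 examples and 12 time steps are arbitrary choices), and the names `sens`, `most_influential`, and `future` are made up for this example. It also checks two properties one would expect: summing the Jacobian over output steps reproduces tape.gradient(logits, x), and a unidirectional LSTM output does not depend on later inputs.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Assumed toy sizes (not from the question): 3 examples, 12 time steps, 1 feature
batch, steps = 3, 12
x = tf.constant(np.random.normal(size=(batch, steps, 1)), dtype=tf.float32)

# Same architecture as in the question, with fewer units and no training
inputs = tf.keras.Input(shape=(None, 1))
hidden = tf.keras.layers.LSTM(8, return_sequences=True)(inputs)
outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(hidden)
model = tf.keras.Model(inputs, outputs)

# persistent=True because the tape is queried twice below
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a plain tensor, so it has to be watched explicitly
    y = model(x, training=False)

jac = tape.batch_jacobian(y, x)  # shape (batch, t_out, 1, t_in, 1)
summed = tape.gradient(y, x)     # shape (batch, t_in, 1), summed over all outputs
del tape

# Drop the singleton feature axes: sens[i, t_out, t_in]
sens = jac[:, :, 0, :, 0].numpy()

# For every output time step, the input time step with the largest |gradient|
most_influential = np.abs(sens).argmax(axis=-1)  # shape (batch, t_out)

# Summing the Jacobian over output steps should reproduce tape.gradient(y, x)
sum_matches = np.allclose(sens.sum(axis=1), summed.numpy()[:, :, 0], atol=1e-4)

# Entries with t_in > t_out should be zero for a unidirectional LSTM
future = np.triu(np.ones((steps, steps), dtype=bool), k=1)
causal = np.allclose(sens[:, future], 0.0, atol=1e-6)

print(jac.shape)
print(most_influential[0])
print(sum_matches, causal)
```

Ranking by absolute gradient is just one possible definition of "most influential"; other attribution measures may suit the real data better.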
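As for getting None gradients in your example, I think it is because you perform the slicing operation outside the gradient tape context, so the gradient tracking is lost.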
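A minimal sketch of this effect, using a toy function instead of the LSTM (the tensor values and the names `grad_outside` and `grad_inside` are chosen only for illustration): the same slice yields None when it is taken after the tape has stopped recording, and a proper gradient when it is taken inside the context.

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])

# Case 1: the slice is taken after the tape has stopped recording
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x
grad_outside = tape.gradient(y[:, 1], x)  # the slice op is not on the tape

# Case 2: the slice is taken inside the context, so it is recorded
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x
    y_1 = y[:, 1]
grad_inside = tape.gradient(y_1, x)

print(grad_outside)         # None
print(grad_inside.numpy())  # [[0. 4. 0.]]
```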
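So the solution is to create a temporary tensor for the part of the logits that we need to use in tape.gradient, and to register that tensor on the tape with tape.watch.
This should do it: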
for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
    # Open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
    with tf.GradientTape() as tape:
        tape.watch(x_batch_tr)
        # Run the forward pass of the model
        logits = my_model(x_batch_tr, training=True)
        # Slice the time step of interest while the tape is still recording
        tensor_logits = tf.constant(logits[:, 10, :])
        tape.watch(tensor_logits)
        # Compute the loss value for this minibatch
        loss_value = my_loss_fn(y_batch_tr, logits)
    # Use the gradient tape to retrieve the gradient of the selected output time step with respect to the inputs
    grads = tape.gradient(tensor_logits, x_batch_tr)
    print(grads.numpy())
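As a hedged cross-check of the two answers, the sketch below takes the slice inside the tape context directly, without the extra tf.constant and tape.watch calls, and compares the result with the matching row of the batch Jacobian. It assumes that recording the slice on the tape is what matters. The model is an untrained, smaller stand-in for the one in the question, and `t_out = 5` is an arbitrary choice.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(1)

# Assumed toy sizes (not from the question): 3 examples, 12 time steps, 1 feature
x = tf.constant(np.random.normal(size=(3, 12, 1)), dtype=tf.float32)

inputs = tf.keras.Input(shape=(None, 1))
hidden = tf.keras.layers.LSTM(8, return_sequences=True)(inputs)
outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(hidden)
model = tf.keras.Model(inputs, outputs)

t_out = 5  # output time step of interest (arbitrary)

# persistent=True because the tape is queried twice below
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = model(x, training=False)
    y_t = y[:, t_out, :]  # slicing inside the context keeps the op on the tape

grad_t = tape.gradient(y_t, x)   # shape (3, 12, 1)
jac = tape.batch_jacobian(y, x)  # shape (3, 12, 1, 12, 1)
del tape

# Each example's output depends only on its own input, so both results agree
same = np.allclose(grad_t.numpy(), jac[:, t_out, 0, :, :].numpy(), atol=1e-5)
print(grad_t.shape, same)
```

For a single time step, the sliced tape.gradient call is the cheaper option; the batch Jacobian is the better fit when all output time steps are needed at once.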