
TensorFlow 2 eager execution disabled inside a custom layer

I'm using TF2, installed via pip, on an Ubuntu 18.04 box:

$ pip freeze | grep "tensorflow"

And I'm experimenting with a custom layer:

import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Input, Concatenate, Dense, Bidirectional, LSTM, Embedding
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import imdb

class Attention(tf.keras.layers.Layer):

    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

vocab_size = 10000

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)

max_len = 200
rnn_cell_size = 128

x_train = sequence.pad_sequences(x_train, maxlen=max_len, padding='post')
x_test = sequence.pad_sequences(x_test, maxlen=max_len, truncating='post', padding='post')

# Network

sequence_input = Input(shape=(max_len,), dtype='int32')

embedded_sequences = Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

# lstm = Bidirectional(LSTM(rnn_cell_size, dropout=0.3, return_sequences=True, return_state=True), name="bi_lstm_0")(embedded_sequences)

lstm, forward_h, forward_c, backward_h, backward_c = Bidirectional(LSTM(rnn_cell_size, dropout=0.2, return_sequences=True, return_state=True))(embedded_sequences)

state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])

attention = Attention(8)

context_vector, attention_weights = attention(lstm, state_h)

output = Dense(1, activation='sigmoid')(context_vector)

model = Model(inputs=sequence_input, outputs=output)

# summarize layers
model.summary()

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=10, batch_size=200, validation_split=.3, verbose=1)

result = model.evaluate(x_test, y_test)

I would like to debug/inspect the Attention.call() function, but I can't get the tensor values when I set a breakpoint inside the function.
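For the debugging goal described here, one option that works even when call() is traced into a graph is tf.print: unlike Python's print(), it becomes an op in the graph and prints runtime values on every execution. The sketch below uses a hypothetical toy layer, DebugDense, and a hypothetical forward() wrapped in @tf.function to stand in for the graph that Keras builds; it is an illustration under those assumptions, not the Attention layer itself.

```python
import tensorflow as tf


class DebugDense(tf.keras.layers.Layer):
    """Hypothetical toy layer that reports an intermediate tensor from call()."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        score = tf.nn.tanh(self.dense(inputs))
        # print() is plain Python: inside a traced graph it runs once, at trace
        # time, and shows a symbolic tensor without values.
        print("Python print:", score)
        # tf.print() is added to the graph as an op, so it prints the actual
        # values every time the graph executes.
        tf.print("tf.print values:", score)
        return score


layer = DebugDense(3)
layer(tf.zeros((1, 4)))  # build the weights eagerly before any tracing


@tf.function
def forward(x):
    # Stand-in for the graph Keras builds around call() during fit().
    return layer(x)


out = forward(tf.ones((2, 4)))
```

The same tf.print call could be placed inside Attention.call(), for example on attention_weights, to see values during model.fit() without changing how the model is compiled.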

Before I call .fit(), I can verify that eager execution is enabled:


But inside the Attention.call() function, eager execution is disabled:


Why is eager execution disabled (False) while call() runs? How can I enable it?

By default, a tf.keras model is compiled into a static graph to deliver the best execution performance. Think of it as if the model's training and evaluation steps were wrapped in @tf.function by default, so call() is traced into that graph instead of being executed eagerly.
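A minimal sketch of that behaviour, outside Keras: the same Python function reports eager execution when called directly, but not when it is wrapped in tf.function, because its Python body then runs while a graph is being traced. The function name double and the modes list are illustrative only.

```python
import tensorflow as tf

modes = []


def double(x):
    # Runs as ordinary Python and records whether TF is executing eagerly now.
    modes.append(tf.executing_eagerly())
    return x * 2


double(tf.constant(1.0))               # plain call: executes eagerly
tf.function(double)(tf.constant(1.0))  # wrapped call: body runs during tracing

print(modes)  # [True, False]
```

This is the same reason a breakpoint inside Attention.call() sees tf.executing_eagerly() == False during model.fit().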

https://www.tensorflow.org/api_docs/python/tf/keras/Model#run_eagerly

To explicitly enable eager mode for a tf.keras model, compile the model with run_eagerly=True:

model.compile(optimizer='adam', run_eagerly=True, loss='binary_crossentropy', metrics=['accuracy'])
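To check the effect of run_eagerly, the sketch below fits a tiny model twice and records tf.executing_eagerly() from inside a custom layer's call(). ModeProbe, MODES and fit_and_probe are hypothetical names used only for this illustration, and the random data stands in for the IMDB sequences. Note that call() is also invoked symbolically while a functional or Sequential model is being built, where eager execution is typically off even with run_eagerly=True, so a breakpoint may first stop there; the sketch clears that entry before fitting.

```python
import numpy as np
import tensorflow as tf

# Hypothetical log of tf.executing_eagerly() values, kept as a plain
# module-level list so Keras attribute tracking does not wrap it.
MODES = []


class ModeProbe(tf.keras.layers.Layer):
    """Hypothetical identity layer that records the execution mode in call()."""

    def call(self, inputs):
        MODES.append(tf.executing_eagerly())
        return inputs


def fit_and_probe(run_eagerly):
    """Builds a tiny model around ModeProbe, fits it briefly, returns the modes seen."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        ModeProbe(),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse", run_eagerly=run_eagerly)
    x = np.random.rand(8, 4).astype("float32")
    y = np.random.rand(8, 1).astype("float32")
    MODES.clear()  # discard the symbolic call made while the model was built
    model.fit(x, y, epochs=1, batch_size=4, verbose=0)
    return list(MODES)


graph_modes = fit_and_probe(run_eagerly=False)
eager_modes = fit_and_probe(run_eagerly=True)
print("run_eagerly=False:", graph_modes)
print("run_eagerly=True: ", eager_modes)
```

With run_eagerly=False the list should contain only False (call() is traced into the training graph); with run_eagerly=True it should contain True entries for the eagerly executed training steps. Depending on the Keras version, an extra symbolic call may also appear at the start of fit(). Running eagerly is slower, so it is usually worth switching it back off once debugging is done.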

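Notice: technical posts on this site are published under the CC BY-SA 4.0 license. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.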

© 2020-2024 STACKOOM.COM