
TensorFlow 2 eager execution disabled inside a custom layer

I'm using TF2 installed via pip on an Ubuntu 18.04 box:

$ pip freeze | grep "tensorflow"
tensorflow==2.0.0
tensorflow-estimator==2.0.1

And I'm playing with a custom layer:

import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Input, Concatenate, Dense, Bidirectional, LSTM, Embedding
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import imdb

class Attention(tf.keras.layers.Layer):

    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

vocab_size = 10000

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)

max_len = 200
rnn_cell_size = 128

x_train = sequence.pad_sequences(x_train, maxlen=max_len, padding='post')
x_test = sequence.pad_sequences(x_test, maxlen=max_len, truncating='post', padding='post')

# Network

sequence_input = Input(shape=(max_len,), dtype='int32')

embedded_sequences = Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

# lstm = Bidirectional(LSTM(rnn_cell_size, dropout=0.3, return_sequences=True, return_state=True), name="bi_lstm_0")(embedded_sequences)

lstm, forward_h, forward_c, backward_h, backward_c = Bidirectional(LSTM(rnn_cell_size, dropout=0.2, return_sequences=True, return_state=True))(embedded_sequences)

state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])

attention = Attention(8)

context_vector, attention_weights = attention(lstm, state_h)

output = Dense(1, activation='sigmoid')(context_vector)

model = Model(inputs=sequence_input, outputs=output)

# summarize layers
model.summary()  # summary() prints the table itself and returns None

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=10, batch_size=200, validation_split=.3, verbose=1)

result = model.evaluate(x_test, y_test)
print(result)

I would like to debug/inspect the Attention.call() function, but I can't get the tensor values when I set a breakpoint inside the function.

Before I call .fit(), I can verify that eager execution is enabled:

print(tf.executing_eagerly())
True

But inside the Attention.call() function, eager execution is disabled:

print(tf.executing_eagerly())
False

Is there a reason eager execution is disabled while call() runs? How can I enable it?

By default, a tf.keras model is compiled to a static graph to deliver the best execution performance. You can think of it as if the model's training and inference steps were decorated with @tf.function by default. That is why tf.executing_eagerly() returns False inside call(), and why the tensors you see at a breakpoint there are symbolic and carry no concrete values.
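The effect can be reproduced without Keras at all. The following is a minimal sketch, assuming TensorFlow 2.x with default settings; the names eager_flag, traced_step, and recorded are made up for illustration. It checks tf.executing_eagerly() once at the top level and once inside a function decorated with @tf.function.

```python
import tensorflow as tf


def eager_flag():
    # Plain Python helper: reports the mode of whatever context calls it.
    return tf.executing_eagerly()


recorded = []  # filled while TensorFlow traces traced_step


@tf.function
def traced_step(x):
    # This Python body runs while the function is traced into a graph,
    # so the flag recorded here reflects graph mode, not eager mode.
    recorded.append(eager_flag())
    return x * 2.0


outside = eager_flag()                    # called at the top level
result = traced_step(tf.constant(3.0))    # triggers tracing, then runs the graph

print("outside tf.function:", outside)
print("inside tf.function: ", recorded[-1])
print("result:", float(result))
```

The top-level check should report eager mode, while the check inside traced_step should not. This is the same situation a custom layer's call() is in during fit().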

https://www.tensorflow.org/api_docs/python/tf/keras/Model#run_eagerly

To enable eager mode explicitly for a tf.keras model, compile the model with run_eagerly=True:

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'], run_eagerly=True)
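To check the difference on a small scale, here is a hedged sketch that is not taken from the original code. Probe is a hypothetical pass-through layer, SEEN is an illustrative module-level log, and the random data is a stand-in for a real dataset. It trains a tiny model twice and reports the execution mode seen inside call() each time.

```python
import numpy as np
import tensorflow as tf

SEEN = []  # module-level log, kept outside the layer on purpose


class Probe(tf.keras.layers.Layer):
    """Hypothetical pass-through layer that records the execution mode."""

    def call(self, inputs):
        SEEN.append(tf.executing_eagerly())
        return inputs


def mode_inside_call(run_eagerly):
    """Train a tiny model and return the mode seen in the last call()."""
    SEEN.clear()
    model = tf.keras.Sequential([
        Probe(),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  run_eagerly=run_eagerly)
    # Random stand-in data: 8 samples with 4 features each, binary labels.
    x = np.random.rand(8, 4).astype("float32")
    y = np.random.randint(0, 2, size=(8, 1)).astype("float32")
    model.fit(x, y, epochs=1, batch_size=4, verbose=0)
    return SEEN[-1]


print("run_eagerly=True :", mode_inside_call(True))
print("run_eagerly=False:", mode_inside_call(False))
```

With run_eagerly=True, a breakpoint inside call() sees concrete tensor values, at the cost of slower training. It is usually worth switching back to the default once debugging is done.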
