TensorFlow model to Keras functional API?

I want to turn this model into a functional model using the Keras API, but I'm not sure how. I want it in the form model = tf.keras.models.Model(...) so that I can evaluate or export it by calling methods on model. But I don't know how to do this with attention layers in the model. The Keras attention layer documentation stops at that very step and leaves it to the user to figure out.

FYI, my model uses IMDB reviews for sentiment analysis.

query_layer = tf.keras.layers.Conv1D(filters=100, kernel_size=4, padding='same')
value_layer = tf.keras.layers.Conv1D(filters=100, kernel_size=4, padding='same')

attention = tf.keras.layers.Attention()
concat = tf.keras.layers.Concatenate()

cells = [tf.keras.layers.LSTMCell(256), tf.keras.layers.LSTMCell(64)]
rnn = tf.keras.layers.RNN(cells)
output_layer = tf.keras.layers.Dense(1)

for batch in ds['train'].batch(32):
    text = batch['text']
    embeddings = embedding_layer(vectorize_layer(text))
    query = query_layer(embeddings)
    value = value_layer(embeddings)
    query_value_attention = attention([query, value])
    attended_values = concat([query, query_value_attention])
    logits = output_layer(rnn(attended_values))
    loss = binary_crossentropy(
        tf.expand_dims(batch['label'], -1), logits, from_logits=True)

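I'm not sure why you have the "for" loop. With the functional API the layers are wired together once, and batching is left to model.fit.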

Here is an example based on the Keras documentation, with a Dense layer added on the output.

import tensorflow as tf

query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')

# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)

# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
    filters=100,
    kernel_size=4,
    # Use 'same' padding so outputs have the same shape as inputs.
    padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)

# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.Attention()(
    [query_seq_encoding, value_seq_encoding])

# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
    query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
    query_value_attention_seq)

# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
    [query_encoding, query_value_attention])

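# Add the output layer (a single logit, no activation) and create the Model.
output_layer = tf.keras.layers.Dense(1)(input_layer)

model = tf.keras.models.Model(inputs=[query_input, value_input], outputs=output_layer)
# Dense(1) has no activation, so the loss must treat the outputs as logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))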

model.summary()
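
The example above follows the documentation's two-input layout. For the layers in the question, which use a single text input, a functional version might look like the sketch below. It is only a sketch under a few assumptions: the input is already a batch of integer token ids (so the question's vectorize_layer runs before the model), and VOCAB_SIZE, EMBED_DIM and build_model are placeholder names that are not in the original code. The dummy arrays at the end only check that the graph is wired correctly; they are not IMDB data.

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE = 1000  # assumption: vocabulary size of the vectorizer
EMBED_DIM = 64     # assumption: embedding width


def build_model():
    # Variable-length sequences of integer token ids.
    token_ids = tf.keras.Input(shape=(None,), dtype='int32')
    embeddings = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(token_ids)

    # Separate Conv1D encoders for query and value, as in the question.
    query = tf.keras.layers.Conv1D(
        filters=100, kernel_size=4, padding='same')(embeddings)
    value = tf.keras.layers.Conv1D(
        filters=100, kernel_size=4, padding='same')(embeddings)

    # Attention output has shape [batch_size, T, 100]; concatenating it
    # with the query gives [batch_size, T, 200].
    query_value_attention = tf.keras.layers.Attention()([query, value])
    attended_values = tf.keras.layers.Concatenate()(
        [query, query_value_attention])

    # Stacked LSTM cells; the RNN returns only the last step: [batch_size, 64].
    cells = [tf.keras.layers.LSTMCell(256), tf.keras.layers.LSTMCell(64)]
    rnn_output = tf.keras.layers.RNN(cells)(attended_values)

    # One logit per example (no activation).
    logits = tf.keras.layers.Dense(1)(rnn_output)

    model = tf.keras.Model(inputs=token_ids, outputs=logits)
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
    return model


model = build_model()

# Random stand-in data: 8 sequences of 20 token ids with binary labels.
dummy_ids = np.random.randint(0, VOCAB_SIZE, size=(8, 20)).astype('int32')
dummy_labels = np.random.randint(0, 2, size=(8, 1)).astype('float32')

loss = model.evaluate(dummy_ids, dummy_labels, verbose=0)
predictions = model.predict(dummy_ids, verbose=0)
```

Once the model is built this way, model.fit, model.evaluate and model.save can replace the manual batch loop.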