
How to log the model graph to TensorBoard when using a combination of the Functional API and tf.GradientTape() to train in TensorFlow 2.0?

Can someone please guide me on how to log the model graph to TensorBoard when I use the Keras Functional API or the model subclassing API to create the model, and tf.GradientTape() to train it?

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Assumed data (not shown in the question): MNIST, flattened to 784 features.
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# The model already applies softmax, so the loss takes probabilities (from_logits=False).
loss_fn = keras.losses.SparseCategoricalCrossentropy()

batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Record the forward pass so gradients can be computed.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)

        grads = tape.gradient(loss_value, model.trainable_weights)

        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
            print('Seen so far: %s samples' % ((step + 1) * batch_size))

Where should I insert the TensorBoard logging for the model graph?

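The best way to do this (TF 2.3.0, TB 2.3.0) is to use tf.summary tracing (tf.summary.trace_on and tf.summary.trace_export) and call the model through a wrapper function decorated with @tf.function, so that the call is traced into a graph that can be exported.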
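The wrapper matters because the exported trace contains the graphs built by tf.function calls; a plain eager call of the model produces no graph to export. A minimal sketch, using a smaller hypothetical version of the question's model and an assumed function name forward, that traces the forward pass with get_concrete_function and inspects the resulting tf.Graph in-process, without TensorBoard:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# A smaller (hypothetical) version of the question's model, for illustration only.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

@tf.function
def forward(data):
    # Inside a tf.function, the model call is recorded as graph operations.
    return model(data)

# Trace once for batches of any size with 784 float32 features.
concrete = forward.get_concrete_function(
    tf.TensorSpec(shape=(None, 784), dtype=tf.float32))

# Operation types recorded in the traced graph.
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print(op_types)
# Symbolic output tensor of the traced function.
print(concrete.structured_outputs)
```

This graph is what tf.summary.trace_export serializes for TensorBoard; nothing comparable exists if the model is only run eagerly.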

In your case, to export the model graph to TensorBoard for inspection:

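import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


# Solution Code:
writer = tf.summary.create_file_writer('./logs/')

# Wrapping the call in a tf.function makes TensorFlow trace it into a graph.
@tf.function
def init_model(data, model):
    model(data)

# Start recording graphs before the first (tracing) call of the tf.function.
tf.summary.trace_on(graph=True)
init_model(tf.zeros((1, 784)), model)

# Write the recorded graph to the log directory ('model_trace' is an arbitrary tag).
with writer.as_default():
    tf.summary.trace_export(name='model_trace', step=0)
writer.flush()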
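For the GradientTape loop in the question, one option is to move the per-batch work into a @tf.function training step and trace only its first call, so the graph is logged once and later steps run without tracing overhead. A minimal sketch under these assumptions: random NumPy data stands in for x_train/y_train, and the log directory './logs/train' and tag 'train_step' are hypothetical names. The exported graph should then cover the forward pass, the gradient computation and the optimizer update, not just the model.

```python
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Same architecture as in the question.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# Hypothetical random data standing in for x_train / y_train.
x_train = np.random.rand(256, 784).astype('float32')
y_train = np.random.randint(0, 10, size=(256,)).astype('int64')
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(64)

@tf.function
def train_step(x_batch, y_batch):
    # Forward pass, gradients and update all become part of one traced graph.
    with tf.GradientTape() as tape:
        probs = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, probs)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

log_dir = './logs/train'  # hypothetical log directory
writer = tf.summary.create_file_writer(log_dir)

for epoch in range(2):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        trace_this_step = (epoch == 0 and step == 0)
        if trace_this_step:
            # Enable graph collection before the first call traces train_step.
            tf.summary.trace_on(graph=True)
        loss_value = train_step(x_batch, y_batch)
        if trace_this_step:
            # Export the collected graph once; trace_export also stops the trace.
            with writer.as_default():
                tf.summary.trace_export(name='train_step', step=0)
writer.flush()

event_files = [f for f in os.listdir(log_dir) if f.startswith('events.out.tfevents')]
print('final loss:', float(loss_value), 'event files:', event_files)
```

Tracing only the first step is a design choice: exporting on every step would write the same graph repeatedly.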

Opening the log directory in TensorBoard (for example with tensorboard --logdir ./logs), you should see the model graph in the Graphs tab.
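To confirm that the graph was written without opening the UI, one option is to read the event file back with tf.compat.v1.train.summary_iterator and list the summary tags. A minimal sketch, assuming a hypothetical one-layer model, the log directory './logs/graph_check' and the tag 'model_trace'; the exported trace is stored under the name passed to trace_export:

```python
import glob
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical one-layer model; the tracing steps mirror the answer's code.
inputs = keras.Input(shape=(784,), name='digits')
outputs = layers.Dense(10, activation='softmax', name='predictions')(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

log_dir = './logs/graph_check'  # hypothetical log directory
writer = tf.summary.create_file_writer(log_dir)

@tf.function
def init_model(data):
    return model(data)

tf.summary.trace_on(graph=True)
init_model(tf.zeros((1, 784)))
with writer.as_default():
    tf.summary.trace_export(name='model_trace', step=0)
writer.flush()

# Collect the summary tags from every event file in the directory.
tags = set()
for path in glob.glob(os.path.join(log_dir, 'events.out.tfevents*')):
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            tags.add(value.tag)
print(tags)
```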

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 