Can someone please guide me on how to log the model graph to TensorBoard when I use the Keras Functional API or the model subclassing API to create the model, and tf.GradientTape() to train it?
# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# NOTE: the output layer already applies softmax, so the values called
# `logits` below are actually probabilities; the loss is constructed with
# its default settings (from_logits is not set here).
loss_fn = keras.losses.SparseCategoricalCrossentropy()

batch_size = 64
# Assumes x_train / y_train are defined earlier (not shown in this snippet).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Forward pass and loss must both run inside the tape so that
        # tape.gradient can differentiate the loss w.r.t. the weights.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
            # Derive the count from batch_size instead of repeating the literal 64.
            print('Seen so far: %s samples' % ((step + 1) * batch_size))
Where should I insert the TensorBoard logging for the model graph?
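The best way to do this (TF 2.3.0, TB 2.3.0) is to use tf.summary and pass the model through a wrapper function decorated with @tf.function.
Graph tracing only records TensorFlow graphs, and a plain eager call to a Keras model does not build one, so the forward pass has to run inside a tf.function while tracing is switched on.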
In your case, to export the model graph to TensorBoard for inspection:
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Solution Code:
# Summary writer for the directory TensorBoard will read the graph from.
writer = tf.summary.create_file_writer('./logs/')


@tf.function
def init_model(data, model):
    """Run one forward pass of *model* on *data* inside a tf.function.

    Because of the @tf.function decorator, the call is traced into a
    graph, which is what tf.summary.trace_export below writes out.
    """
    model(data)


# Start recording traces, then trigger tracing with a dummy batch whose
# shape matches the model input (1 sample x 784 features, per keras.Input).
tf.summary.trace_on(graph=True)
init_model(tf.zeros((1, 784)), model)
with writer.as_default():
    # Export the recorded graph to ./logs/ under the tag 'name' at step 0.
    tf.summary.trace_export('name', step=0)
Open the log directory in TensorBoard (for example, `tensorboard --logdir ./logs`), and you should see the model graph in the Graphs tab.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.