簡體   English   中英

使用功能 API 和 tf.GradientTape() 的組合在 Tensorflow 2.0 中進行訓練時,如何將模型圖記錄到張量板?

[英]How to log model graph to tensorboard when using combination to Functional API and tf.GradientTape() to train in Tensorflow 2.0?

當我使用 Keras Functional API 或模型子調用 API 來創建模型和 tf.GradientTape() 來訓練模型時,有人可以指導我如何將模型圖記錄到張量板嗎?

# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()


batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):   

    with tf.GradientTape() as tape:
        logits = model(x_batch_train)
        loss_value = loss_fn(y_batch_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)


    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    if step % 200 == 0:
        print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
        print('Seen so far: %s samples' % ((step + 1) * 64))

我應該在哪里插入模型圖的張量板日志記錄?

執行此操作的最佳方法(TF 2.3.0,TB 2.3.0)是使用tf.summary並通過帶有@tf.function裝飾器的包裝函數傳遞模型。

在您的情況下,要將模型圖導出到 TensorBoard 進行檢查:

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


# Solution Code:
writer = tf.summary.create_file_writer('./logs/')

@tf.function
def init_model(data, model):
    model(data)

tf.summary.trace_on(graph=True)
init_model(tf.zeros((1,784), model)

with writer.as_default():
    tf.summary.trace_export('name', step=0)

查看 TensorBoard 中的日志文件,您應該會看到這樣的模型圖。 在此處輸入圖片說明

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM