How to log the model graph to TensorBoard when combining the Functional API with tf.GradientTape() for training in TensorFlow 2.0?
Can someone guide me on how to log the model graph to TensorBoard when I create the model with the Keras Functional API or the model subclassing API and train it with tf.GradientTape()?
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# x_train and y_train are assumed to be loaded already (not shown here).
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Record the forward pass so the tape can compute gradients.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
            print('Seen so far: %s samples' % ((step + 1) * batch_size))
Where should I insert the TensorBoard logging of the model graph?
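The best way to do this (as of TF 2.3.0, TB 2.3.0) is to use tf.summary and pass the model through a wrapper function decorated with @tf.function. The decorator makes TensorFlow trace the call into a graph, and tf.summary can then export that traced graph.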
In your case, to export the model graph to TensorBoard for inspection:
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Solution Code:
writer = tf.summary.create_file_writer('./logs/')

@tf.function
def init_model(data, model):
    # Calling the model inside a tf.function makes TensorFlow trace it as a graph.
    model(data)

# Start recording, run one dummy batch through the model, then export the trace.
tf.summary.trace_on(graph=True)
init_model(tf.zeros((1, 784)), model)
with writer.as_default():
    tf.summary.trace_export('name', step=0)
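
As for where this goes in the training loop from the question: one possible arrangement is to move the GradientTape logic into a @tf.function and trace only its first call, so the exported graph also contains the loss and the optimizer update. The following is a minimal sketch rather than a definitive implementation. The names train_step and graph_logged, the trace name 'train_step_graph', and the random stand-in data are assumptions made for illustration and do not come from the original post.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Same model as in the question.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# Assumption: random stand-in data so the sketch runs without a real dataset.
x_train = tf.random.uniform((256, 784))
y_train = tf.random.uniform((256,), maxval=10, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(64)

log_dir = './logs/'
writer = tf.summary.create_file_writer(log_dir)

@tf.function
def train_step(x_batch, y_batch):
    # Hypothetical helper: one optimization step, traced as a single graph.
    with tf.GradientTape() as tape:
        probs = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, probs)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

graph_logged = False
for step, (x_batch, y_batch) in enumerate(train_dataset):
    if not graph_logged:
        # Start recording just before the first (tracing) call.
        tf.summary.trace_on(graph=True)
    loss_value = train_step(x_batch, y_batch)
    if not graph_logged:
        # Export the recorded graph once; later steps are not traced.
        with writer.as_default():
            tf.summary.trace_export('train_step_graph', step=0)
        graph_logged = True

writer.flush()
```

After running this, starting TensorBoard with tensorboard --logdir ./logs should show the traced function under the Graphs tab. Tracing only the first call keeps the overhead out of the rest of the loop.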