
How to fix the batch size in a Keras subclassing model?

In the tf.keras functional API, I can fix the batch size like this:

import tensorflow as tf

inputs = tf.keras.Input(shape=(64, 64, 3), batch_size=1)    # I can fix batch size like this
x = tf.keras.layers.Conv2DTranspose(3, 3, strides=2, padding="same", activation="relu")(inputs)
outputs = x
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="custom")

My question is: how do I fix the batch size when I use the Keras subclassing API?

One way to handle the parameter indirectly (when you cannot access it directly) is to go through tf.keras.backend. In this case, the input shape is resolved inside the call method, so the batch size is read from the incoming tensor at run time:

def call(self, inputs):
    z_mean, z_log_var = inputs
    batch = tf.shape(z_mean)[0]   # batch size taken from the input tensor at run time
    dim = tf.shape(z_mean)[1]     # latent dimension
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
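
For context, this call method comes from a subclassed sampling layer, as in the standard Keras VAE example. A minimal self-contained sketch (the class name Sampling and the test shapes below are illustrative, not part of the original answer):

import tensorflow as tf

class Sampling(tf.keras.layers.Layer):
    # Reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon.
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]   # batch size read from the input at run time
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

z = Sampling()((tf.zeros((4, 2)), tf.zeros((4, 2))))
print(z.shape)   # (4, 2): the batch dimension follows whatever was passed in

Because the shape comes from tf.shape at run time, the layer works with any batch size rather than a fixed one.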

Then loop over the batches:

for step, x_batch_train in enumerate(train_dataset):
    with tf.GradientTape() as tape:
        reconstructed = vae(x_batch_train)
        # Compute reconstruction loss
        loss = mse_loss_fn(x_batch_train, reconstructed)
        loss += sum(vae.losses)  # Add KLD regularization loss

    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))

    loss_metric(loss)

    if step % 100 == 0:
        print("step %d: mean loss = %.4f" % (step, loss_metric.result())

