简体   繁体   English

如何修复keras子类模型中的批量大小?

[英]How to fix the batch size in keras subclassing model?

In tf.keras functional API, I can fix the batch size like below:在 tf.keras 函数式 API 中,我可以像下面这样固定批量大小:

import tensorflow as tf

inputs = tf.keras.Input(shape=(64, 64, 3), batch_size=1)    # I can fix batch size like this
x = tf.keras.layers.Conv2DTranspose(3, 3, strides=2, padding="same", activation="relu")(inputs)
outputs = x
model = keras.Model(inputs=inputs, outputs=outputs, name="custom")

My question is, how do I can fix the batch size when I use the keras subclassing approach?我的问题是,当我使用 keras 子类化方法时,如何修复批量大小?

One way of dealing with your parameters indirectly (when there's no reach to it) is using tf.keras.backend access.间接处理参数的一种方法(当无法访问它时)是使用tf.keras.backend访问。 In this case, tf defines the input format through call function:在这种情况下,tf 通过调用函数来定义输入格式:

def call(self, inputs):
    z_mean, z_log_var = inputs
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

And then iterates over each batch然后遍历每个批次

for step, x_batch_train in enumerate(train_dataset):
    with tf.GradientTape() as tape:
        reconstructed = vae(x_batch_train)
        # Compute reconstruction loss
        loss = mse_loss_fn(x_batch_train, reconstructed)
        loss += sum(vae.losses)  # Add KLD regularization loss

    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))

    loss_metric(loss)

    if step % 100 == 0:
        print("step %d: mean loss = %.4f" % (step, loss_metric.result())

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM