Does the @tf.function decorator work with class attributes?
I am currently developing an Autoencoder class. One of its methods is as follows:
@tf.function
def build_ae(self, inputs):
    self.input_ae = inputs
    self.encoded = self.encoder(self.input_ae)
    self.decoded = self.decoder(self.encoded)
    self.autoencoder = Model(self.input_ae, outputs=self.decoded)
If I try to run this method:
inp = Input(shape=(Nx, Nx, Nu)) # Nx, Nu are both int
ae = Autoencoder(Nx, Nu, [32, 64, 128]) # Simply specifying layers + input dimensions
ae.build_ae(inp)
I get the following error:
TypeError: Cannot convert a symbolic Keras input/output to a numpy array.
This error may indicate that you're trying to pass a symbolic value to a NumPy call,
which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs
to a TF API that does not register dispatching, preventing Keras from automatically
converting the API call to a lambda layer in the Functional Model.
However, when I remove the @tf.function decorator, the method works as expected.
I tried writing a simple test example:
class Test:
    @tf.function
    def build_test(self, inputs):
        self.inp = inputs

t = Test()
input_t = Input(shape=(3, 3, 3))
t.build_test(input_t)
Again, this results in the same error.
I tried disabling eager execution, but that had no effect.
Does anyone know why this might not work?
Update:
Here is the full Autoencoder class:
import einops
import h5py
from pathlib import Path
from typing import List, Tuple
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    Dense,
    Conv2D,
    MaxPool2D,
    UpSampling2D,
    concatenate,
    BatchNormalization,
    Conv2DTranspose,
    Flatten,
    PReLU,
    Reshape,
    Dropout,
    AveragePooling2D,
    add,
    Lambda,
    Layer,
    TimeDistributed,
    LSTM
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
class Autoencoder:
    def __init__(self,
                 nx: int,
                 nu: int,
                 features_layers: List[int],
                 latent_dim: int = 16,
                 filter_window: Tuple[int, int] = (3, 3),
                 act_fn: str = 'tanh',
                 batch_norm: bool = False,
                 dropout_rate: float = 0.0,
                 lmb: float = 0.0,
                 resize_method: str = 'bilinear') -> None:
        self.nx = nx
        self.nu = nu
        self.features_layers = features_layers
        self.latent_dim = latent_dim
        self.filter_window = filter_window
        self.act_fn = act_fn
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate
        self.lmb = lmb
        self.resize_method = resize_method
        self.train_history = None
        self.val_history = None
        self.encoder = Encoder(
            self.nx,
            self.nu,
            self.features_layers,
            self.latent_dim,
            self.filter_window,
            self.act_fn,
            self.batch_norm,
            self.dropout_rate,
            self.lmb
        )
        self.decoder = Decoder(
            self.nx,
            self.nu,
            self.features_layers,
            self.latent_dim,
            self.filter_window,
            self.act_fn,
            self.batch_norm,
            self.dropout_rate,
            self.lmb,
            self.resize_method
        )

    @tf.function
    def build_ae(self, inputs):
        self.input_ae = inputs
        self.encoded = self.encoder(self.input_ae)
        self.decoded = self.decoder(self.encoded)
        self.autoencoder = Model(self.input_ae, outputs=self.decoded)

    @tf.function
    def compile_ae(self, learning_rate, loss):
        self.autoencoder.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss)

    def train(self,
              inputs,
              targets,
              inputs_valid,
              targets_valid,
              n_epoch,
              batch_size,
              learning_rate,
              patience,
              filepath):
        model_cb = ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True, verbose=1, save_format="h5")
        early_cb = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)
        cb = [model_cb, early_cb]
        self.train_history = []
        self.val_history = []
        tf.keras.backend.set_value(self.autoencoder.optimizer.lr, learning_rate)
        hist = self.autoencoder.fit(
            inputs,
            targets,
            epochs=n_epoch,
            batch_size=batch_size,
            shuffle=True,
            validation_data=(inputs_valid, targets_valid),
            callbacks=cb
        )
        self.train_history.extend(hist.history['loss'])
        # Validation losses belong in val_history, not train_history.
        self.val_history.extend(hist.history['val_loss'])
        return self.train_history, self.val_history
Either of the following approaches seems to work:
tf.config.run_functions_eagerly(True)

class Test:
    @tf.function
    def build_test(self, inputs):
        self.inp = inputs

t = Test()
input_t = tf.keras.layers.Input(shape=(3, 3, 3))
t.build_test(input_t)
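Keep in mind that tf.config.run_functions_eagerly(True) is a global switch intended mainly for debugging: it makes every tf.function in the program run as plain Python, not just this one. A minimal sketch of limiting it to one call and restoring the previous setting afterwards, assuming TensorFlow 2.3+ (where tf.config.functions_run_eagerly() is available); the context manager name eager_functions is hypothetical:

```python
import contextlib

import tensorflow as tf


@contextlib.contextmanager
def eager_functions():
    """Hypothetical helper: temporarily run all tf.functions eagerly."""
    previous = tf.config.functions_run_eagerly()
    tf.config.run_functions_eagerly(True)
    try:
        yield
    finally:
        # Restore whatever setting was active before entering the block.
        tf.config.run_functions_eagerly(previous)


class Test:
    @tf.function
    def build_test(self, inputs):
        self.inp = inputs


t = Test()
input_t = tf.keras.layers.Input(shape=(3, 3, 3))
with eager_functions():
    # The decorated body now runs as ordinary Python, so the Keras
    # symbolic input is stored unchanged instead of being converted.
    t.build_test(input_t)

print(t.inp is input_t)
```

Outside the with block, other tf.functions go back to being traced into graphs.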
or, passing a regular tensor instead of a Keras symbolic input:
class Test:
    @tf.function
    def build_test(self, inputs):
        self.inp = inputs

t = Test()
t.build_test(tf.constant(1))
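In the tf.constant case, tf.function traces the Python body with graph placeholders. Python side effects such as assigning self.inp run only while tracing, so the attribute ends up holding a symbolic tensor from the trace, not the value of the latest call. A small sketch illustrating this, assuming TensorFlow 2.x; the trace_count attribute is added purely for illustration:

```python
import tensorflow as tf


class Test:
    def __init__(self):
        # Illustrative counter: how many times the Python body actually runs.
        self.trace_count = 0

    @tf.function
    def build_test(self, inputs):
        # These Python side effects happen during tracing only.
        self.trace_count += 1
        self.inp = inputs  # stores the graph placeholder, not a value
        return inputs * 2


t = Test()
a = t.build_test(tf.constant(1))
# Same dtype and shape (int32 scalar), so the existing trace is reused
# and the Python body does not run again.
b = t.build_test(tf.constant(5))
print(t.trace_count, int(a), int(b))
```

So attribute assignment inside a tf.function "works", but it is a one-time tracing effect rather than something that happens on each call.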
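According to the documentation, when you create a tf.keras.Model:
"By default, we will attempt to compile your model to a static graph to deliver the best execution performance."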
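You are already creating a model in the build_ae method, and Keras compiles that model's training and inference steps into a graph by itself (for example when you call fit or predict), so I don't think omitting the @tf.function decorator will affect your performance.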
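A minimal sketch of that structure, assuming TensorFlow 2.x with tf.keras: model construction and compilation are plain Python methods, and fit() takes care of graph compilation. The question does not show the real Encoder and Decoder classes, so TinyAutoencoder and its single Conv2D encoder/decoder layers are hypothetical stand-ins:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D


class TinyAutoencoder:
    """Hypothetical stand-in for the question's Autoencoder class."""

    def __init__(self, nx: int, nu: int, latent_channels: int = 4):
        # Hypothetical one-layer encoder/decoder, for illustration only.
        self.encoder = Conv2D(latent_channels, 3, padding="same", activation="tanh")
        self.decoder = Conv2D(nu, 3, padding="same")
        self.autoencoder = None

    # No @tf.function: wiring layers with Keras symbolic inputs is
    # ordinary Python that runs once.
    def build_ae(self, inputs):
        encoded = self.encoder(inputs)
        decoded = self.decoder(encoded)
        self.autoencoder = Model(inputs, outputs=decoded)
        return self.autoencoder

    def compile_ae(self, learning_rate: float, loss: str):
        self.autoencoder.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=loss,
        )


nx, nu = 8, 2
ae = TinyAutoencoder(nx, nu)
ae.build_ae(Input(shape=(nx, nx, nu)))
ae.compile_ae(learning_rate=1e-3, loss="mse")

x = np.random.rand(16, nx, nx, nu).astype("float32")
# fit() runs its training step as a graph unless run_eagerly=True,
# so no manual decorator is needed here.
history = ae.autoencoder.fit(x, x, epochs=1, batch_size=8, verbose=0)
pred = ae.autoencoder.predict(x, verbose=0)
print(pred.shape)  # (16, 8, 8, 2)
```

If a custom computation really needs @tf.function, it is usually applied to a step that takes real tensors (such as a training step), not to the code that builds the model.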