简体   繁体   中英

Does @tf.function decorator work with class attributes?

I'm currently developing an Autoencoder class; one of its methods is as follows:

@tf.function  # NOTE: with this decorator, calling build_ae on a Keras Input raises the TypeError reported in this question
def build_ae(self, inputs):
    """Wire the encoder and decoder into a Keras functional model.

    Stores the input, the encoded tensor, the decoded tensor and the
    resulting ``Model`` on the instance; returns None.
    """
    self.input_ae = inputs
    self.encoded = self.encoder(self.input_ae)
    self.decoded = self.decoder(self.encoded)
    # An autoencoder maps inputs to their reconstruction, so the model output
    # must be the decoder output; using self.encoded would yield an encoder only.
    self.autoencoder = Model(self.input_ae, outputs=self.decoded)

If I try to run this function:

layer_features = [32, 64, 128]             # feature count of each layer
inp = Input(shape=(Nx, Nx, Nu))            # Nx, Nu are both int
ae = Autoencoder(Nx, Nu, layer_features)   # input dimensions + layer sizes
ae.build_ae(inp)

I get the following error:

TypeError: Cannot convert a symbolic Keras input/output to a numpy array. 
This error may indicate that you're trying to pass a symbolic value to a NumPy call,
which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs 
to a TF API that does not register dispatching, preventing Keras from automatically
converting the API call to a lambda layer in the Functional Model.

However, when I remove the @tf.function decorator, the function works as intended.

I've tried writing a simple test example:

class Test:
    """Minimal reproduction: a tf.function method that only stores its argument."""

    @tf.function
    def build_test(self, inputs):
        # Called with a symbolic Keras Input, this decorated method raises
        # the same TypeError as build_ae.
        self.inp = inputs


t = Test()

symbolic_input = Input(shape=(3, 3, 3))
t.build_test(symbolic_input)

Once again, this results in the same error.

I've tried disabling eager execution and this has had no effect.

Does anyone know why this might not be working?

Update:

Here is the full Autoencoder class:

import einops

import h5py
from pathlib import Path
from typing import List, Tuple

import numpy as np
import tensorflow as tf

from tensorflow.keras import Input

from tensorflow.keras.layers import (
    Dense, 
    Conv2D,
    MaxPool2D,
    UpSampling2D,
    concatenate,
    BatchNormalization,
    Conv2DTranspose,
    Flatten,
    PReLU,
    Reshape,
    Dropout,
    AveragePooling2D,
    add,
    Lambda,
    Layer,
    TimeDistributed,
    LSTM
)

from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

class Autoencoder:
    """Autoencoder assembled from separate ``Encoder`` and ``Decoder`` parts.

    ``Encoder`` and ``Decoder`` are defined elsewhere (not shown). This class:
    - stores their hyper-parameters;
    - wires them into a Keras functional ``Model`` (``build_ae``);
    - compiles that model (``compile_ae``);
    - trains it (``train``).
    """

    def __init__(self,
                 nx: int,
                 nu: int,
                 features_layers: List[int],
                 latent_dim: int = 16,
                 filter_window: Tuple[int, int] = (3, 3),
                 act_fn: str = 'tanh',
                 batch_norm: bool = False,
                 dropout_rate: float = 0.0,
                 lmb: float = 0.0,
                 resize_method: str = 'bilinear') -> None:
        """Store hyper-parameters and construct the encoder and decoder.

        Every argument is stored on the instance and forwarded to ``Encoder``
        and ``Decoder``; ``resize_method`` goes to the decoder only. Their
        exact meaning is defined by those classes. ``nx``/``nu`` presumably
        describe inputs of shape ``(nx, nx, nu)``; confirm against Encoder.
        """
        self.nx = nx
        self.nu = nu

        self.features_layers = features_layers
        self.latent_dim = latent_dim
        self.filter_window = filter_window

        self.act_fn = act_fn
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate
        self.lmb = lmb
        self.resize_method = resize_method

        # Per-epoch loss lists; populated by train().
        self.train_history = None
        self.val_history = None

        self.encoder = Encoder(
            self.nx,
            self.nu,
            self.features_layers,
            self.latent_dim,
            self.filter_window,
            self.act_fn,
            self.batch_norm,
            self.dropout_rate,
            self.lmb
        )

        self.decoder = Decoder(
            self.nx,
            self.nu,
            self.features_layers,
            self.latent_dim,
            self.filter_window,
            self.act_fn,
            self.batch_norm,
            self.dropout_rate,
            self.lmb,
            self.resize_method
        )

    def build_ae(self, inputs):
        """Wire encoder and decoder into ``self.autoencoder``.

        Args:
            inputs: symbolic Keras input, e.g. ``Input(shape=(nx, nx, nu))``.

        Sets ``input_ae``, ``encoded``, ``decoded`` and ``autoencoder`` on the
        instance and returns None.

        Note:
            This method is deliberately NOT decorated with ``@tf.function``.
            Building a functional ``Model`` operates on symbolic Keras
            tensors. Calling a ``tf.function``-wrapped version with a Keras
            ``Input`` raises "TypeError: Cannot convert a symbolic Keras
            input/output to a numpy array".
        """
        self.input_ae = inputs
        self.encoded = self.encoder(self.input_ae)
        self.decoded = self.decoder(self.encoded)
        self.autoencoder = Model(self.input_ae, outputs=self.decoded)

    def compile_ae(self, learning_rate, loss):
        """Compile ``self.autoencoder`` with an Adam optimizer.

        ``build_ae`` must have been called first. Not wrapped in
        ``@tf.function``: ``Model.compile`` only configures the model from
        Python (here it also creates a new optimizer). By default, Keras
        itself compiles the model's training step to a graph.
        """
        self.autoencoder.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss)

    def train(self,
              inputs,
              targets,
              inputs_valid,
              targets_valid,
              n_epoch,
              batch_size,
              learning_rate,
              patience,
              filepath):
        """Fit ``self.autoencoder`` with checkpointing and early stopping.

        Requires ``build_ae`` and ``compile_ae`` to have been called.

        Args:
            inputs, targets: training data passed to ``Model.fit``.
            inputs_valid, targets_valid: validation data.
            n_epoch: maximum number of epochs; early stopping may end sooner.
            batch_size: mini-batch size.
            learning_rate: value written into the compiled optimizer's
                learning rate before fitting.
            patience: epochs without ``val_loss`` improvement before stopping.
            filepath: where ModelCheckpoint saves the best model (by val_loss).

        Returns:
            ``(train_history, val_history)``: the per-epoch training and
            validation losses recorded during this call.
        """
        # NOTE(review): whether ModelCheckpoint honours `save_format` depends on
        # the TF version; the extension of `filepath` may decide the format.
        model_cb = ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True, verbose=1, save_format="h5")
        early_cb = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)

        cb = [model_cb, early_cb]

        self.train_history = []
        self.val_history = []

        # `learning_rate` is the documented attribute name (`lr` is an older alias).
        tf.keras.backend.set_value(self.autoencoder.optimizer.learning_rate, learning_rate)

        hist = self.autoencoder.fit(
            inputs,
            targets,
            epochs=n_epoch,
            batch_size=batch_size,
            shuffle=True,
            validation_data=(inputs_valid, targets_valid),
            callbacks=cb
        )

        self.train_history.extend(hist.history['loss'])
        # Validation losses belong in val_history. Previously they were appended
        # to train_history, leaving val_history always empty.
        self.val_history.extend(hist.history['val_loss'])

        return self.train_history, self.val_history

Either of the following approaches avoids the error:

# Make every tf.function run as plain Python. The function must actually be
# called: a bare `tf.config.run_functions_eagerly` only references it and has
# no effect, so the TypeError would still occur.
tf.config.run_functions_eagerly(True)


class Test:
    """Minimal example: a tf.function method that stores its argument."""

    @tf.function
    def build_test(self, inputs):
        self.inp = inputs


t = Test()
input_t = tf.keras.layers.Input(shape=(3, 3, 3))
t.build_test(input_t)

and

class Test:
    """Stores whatever tensor ``build_test`` receives on the instance."""

    @tf.function
    def build_test(self, inputs):
        # NOTE(review): Python side effects in a tf.function run while it is
        # traced. `self.inp` therefore likely holds the traced graph tensor,
        # not the eager constant that was passed in. Confirm before using it.
        self.inp = inputs


t = Test()
constant_arg = tf.constant(1)
t.build_test(constant_arg)

According to the `tf.keras.Model` documentation, when you create a `tf.keras.Model`:

By default, we will attempt to compile your model to a static graph to deliver the best execution performance.

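Your build_ae method only constructs a Keras model, and Keras already compiles that model's computation to a static graph when you train or evaluate it. Omitting the @tf.function decorator from build_ae should therefore not affect your performance.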

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM