How can I implement a custom PCA layer in my model using the Model Subclassing API?

I am trying to implement a custom PCA layer for a model that I am building with the Keras Model Subclassing API. This is how I have defined the layer:

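import tensorflow as tf
from tensorflow.keras import backend as K

class PCALayer(tf.keras.layers.Layer):
    def __init__(self):
        super(PCALayer, self).__init__()

        # pca is a module-level sklearn.decomposition.PCA that has already been fit
        self.pc = pca

    def call(self, input_tensor, training=False):
        # project the inputs with sklearn and wrap the result as a constant
        x = K.constant(self.pc.transform(input_tensor))
        return x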

The pca object is an sklearn.decomposition.PCA instance that has already been fit on the required data. Only fit was called on it, not transform.

Now, this is how I have added the layer to my model:

class ModelSubClassing(tf.keras.Model):
    def __init__(self, initializer):
        super(ModelSubClassing, self).__init__()
        # define all layers in init
        # Layer of Block 1
        self.pca_layer = PCALayer()
        self.dense1 = tf.keras.layers.Dense(...)
        self.dense2 = tf.keras.layers.Dense(...)
        self.dense3 = tf.keras.layers.Dense(...)


    def call(self, input_tensor, training=False):
        # forward pass: block 1 
        x = self.pca_layer(input_tensor)
        x = self.dense1(x)
        x = self.dense2(x)
        
        return self.dense3(x)

When I compile the model there is no error. However, when I fit the model, I get the following error:

NotImplementedError: Cannot convert a symbolic Tensor (model_sub_classing_1/Cast:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

Can anyone help me, please?

self.pc.transform comes from sklearn and expects a NumPy array, but you are passing it a TensorFlow tensor. When Keras builds the model, it calls the layer with a symbolic tensor in order to trace the graph. A symbolic tensor holds no concrete values, so it cannot be converted to a NumPy array. The error message itself says as much:

you're trying to pass a Tensor to a NumPy call, which is not supported
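
One possible way around this, offered as a minimal sketch rather than a definitive fix, is to keep sklearn out of call entirely: copy the fitted parameters (mean_ and components_, both real attributes of a fitted sklearn PCA) into TensorFlow constants and do the projection with TensorFlow ops. The fitted_pca constructor argument is a name made up for this example, the random data is purely illustrative, and the sketch assumes the PCA was created with the default whiten=False and that float32 precision is acceptable.

```python
import numpy as np
import tensorflow as tf
from sklearn.decomposition import PCA


class PCALayer(tf.keras.layers.Layer):
    """Applies an already-fitted sklearn PCA using TensorFlow ops only."""

    def __init__(self, fitted_pca, **kwargs):
        # fitted_pca is a hypothetical argument name: a PCA that has been fit.
        super().__init__(trainable=False, **kwargs)
        # Copy the fitted parameters out of sklearn as fixed constants.
        self.mean = tf.constant(fitted_pca.mean_, dtype=tf.float32)
        self.components = tf.constant(fitted_pca.components_, dtype=tf.float32)

    def call(self, input_tensor, training=False):
        x = tf.cast(input_tensor, tf.float32)
        # Same arithmetic as PCA.transform with whiten=False:
        # centre the data, then project onto the principal axes.
        return tf.matmul(x - self.mean, self.components, transpose_b=True)


# Illustrative data only: 100 samples with 5 features, reduced to 2 components.
rng = np.random.default_rng(0)
data = rng.normal(size=(100, 5)).astype("float32")
pca = PCA(n_components=2).fit(data)

layer = PCALayer(pca)
projected = layer(tf.constant(data))
```

Because call now contains only TensorFlow ops, it also works on the symbolic tensors that Keras passes in while fitting. In the model from the question, the layer would then be created as PCALayer(pca) inside __init__.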
