[英]How can I implement a custom PCA layer in my model using Model Subclassing API?
I am trying to implement a custom PCA layer for my model being developed using Model Subclassing API. This is how I have defined the layer.我正在尝试为正在使用 Model 子类 API 开发的 model 实现自定义 PCA 层。这就是我定义层的方式。
class PCALayer(tf.keras.layers.Layer):
def __init__(self):
super(PCALayer, self).__init__()
self.pc = pca
def call(self, input_tensor, training=False):
x = K.constant(self.pc.transform(input_tensor))
return x
The pca itself is from sklearn.decomposition.PCA
and has been fit with the needed data and not transformed. pca 本身来自sklearn.decomposition.PCA
并且已经适合所需的数据并且没有被转换。
Now, this is how I have added the layer to my model现在,这就是我将图层添加到我的 model 的方式
class ModelSubClassing(tf.keras.Model):
def __init__(self, initizlizer):
super(ModelSubClassing, self).__init__()
# define all layers in init
# Layer of Block 1
self.pca_layer = PCALayer()
self.dense1 = tf.keras.layers.Dense(...)
self.dense2 = tf.keras.layers.Dense(...)
self.dense3 = tf.keras.layers.Dense(...)
def call(self, input_tensor, training=False):
# forward pass: block 1
x = self.pca_layer(input_tensor)
x = self.dense1(x)
x = self.dense2(x)
return self.dense3(x)
When I compile the model there is no error.当我编译 model 时没有错误。 However, when I fit the model, I get the following error:但是,当我安装 model 时,出现以下错误:
NotImplementedError: Cannot convert a symbolic Tensor (model_sub_classing_1/Cast:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported
Can anyone help me please...谁能帮帮我...
self.pc.transform
which comes from sklearn is expecting a numpy array, but you provide a tf tensor.来自self.pc.transform
的 self.pc.transform 需要一个 numpy 数组,但您提供了一个 tf 张量。 When the layer is built, it passes a symbolic tensor to build the graph etc, and this tensor cannot be converted to a numpy array.构建图层时,它传递一个符号张量来构建图形等,而这个张量不能转换为 numpy 数组。 The answer is in error:答案有误:
you're trying to pass a Tensor to a NumPy call, which is not supported您正在尝试将 Tensor 传递给 NumPy 调用,这是不支持的
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.