繁体   English   中英

自定义图层 Tensorflow 2.1 output 形状的问题

[英]custom layer with Tensorflow 2.1 problem with the output shape

我试图让自定义层返回 (25,1) 张量,但是有一个 batch_size 应该通过(我从下一层得到错误)。 我查找示例,但无法确定如何指定 output 形状。

此外,我需要一个独立于输入大小的任意 output 形状,因为计算(不是下面示例的一部分)将始终返回固定数量的值。

我尝试了以下内容:

class SimpleLayer(layers.Layer):
    def __init__(self, **kwargs):
        super(SimpleLayer,  self).__init__(**kwargs)
        self.baseline = tf.Variable(initial_value=0.1, trainable=True)

    def call(self, inputs):
        print ("in call inputs:", inputs.shape)
        ret = tf.zeros((25, 1)) + self.baseline
        print("Ret:", ret, "Shape", tf.shape(ret))
        return (ret)

这返回:

Ret: Tensor("om/add:0", shape=(25, 1), dtype=float32) Shape Tensor("om/Shape:0", shape=(2,), dtype=int32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
inputs (InputLayer)          [(None, 150, 1)]          0         
_________________________________________________________________
dense (Dense)                (None, 150, 256)          512       
_________________________________________________________________
om (SimpleLayer)             (25, 1)                   1         
=================================================================

但这确实形成了 output 形状 (25, 1) 但不是 (None, 25, 1)。

然后我尝试了:

class SimpleLayer(layers.Layer):
    def __init__(self, **kwargs):
        super(SimpleLayer,  self).__init__(**kwargs)
        self.baseline = tf.Variable(initial_value=0.1, trainable=True)

    def call(self, inputs):
        print ("in call inputs:", inputs.shape)
        ret = tf.zeros((25, 1)) + self.baseline
        return (ret)

并得到错误:

TypeError: Expected int32, got None of type 'NoneType' instead.

有什么建议吗?

我建议您使用调用方法中定义的输入数据,否则该层没有意义

我提供了一个虚拟示例并完美运行

class SimpleLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(SimpleLayer,  self).__init__(**kwargs)
        self.baseline = tf.Variable(initial_value=0.1, trainable=True)

    def call(self, inputs):
        ret = inputs + self.baseline
        return (ret)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2])

使用 SimpleLayer 创建一个 model

inp = Input(shape=(25,1))
x = SimpleLayer()(inp)
out = Dense(3)(x)
model = Model(inp, out)
model.summary()

摘要:

Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        [(None, 25, 1)]           0         
_________________________________________________________________
simple_layer_16 (SimpleLayer (None, 25, 1)             1         
_________________________________________________________________
dense_22 (Dense)             (None, 25, 3)             6         
=================================================================
Total params: 7
Trainable params: 7
Non-trainable params: 0

编辑

我尝试以这种方式覆盖无维度的问题

class SimpleLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(SimpleLayer,  self).__init__(**kwargs)
        self.baseline = tf.Variable(initial_value=0.1, trainable=True, dtype=tf.float64)

    def call(self, inputs):
        ret = tf.zeros((1, 25, 1), dtype=tf.float64) + self.baseline
        ret = tf.compat.v1.placeholder_with_default(ret, (None, 25, 1))
        return (ret)

inp = Input((150,1))
x = Dense(256)(inp)
x = SimpleLayer()(x)
x = Dense(10)(x)

model = Model(inp, x)
model.summary()

摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_34 (InputLayer)        [(None, 150, 1)]          0         
_________________________________________________________________
dense_68 (Dense)             (None, 150, 256)          512       
_________________________________________________________________
simple_layer_9 (SimpleLayer) (None, 25, 1)             1         
_________________________________________________________________
dense_69 (Dense)             (None, 25, 10)            20        
=================================================================
Total params: 533
Trainable params: 533
Non-trainable params: 0

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM