
How to add a dimension using a Reshape layer in Keras

I want to expand the dimensions in my model. Can I replace the tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=1)) layer with a tf.keras.layers.Reshape() layer?

My model is:

model = tf.keras.Sequential([
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(10, activation='relu'), input_shape=(i1, i2)),
    tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=1)),
    tf.keras.layers.Dense(1)
])

I want to replace the Lambda layer.

Modified code:

model = tf.keras.Sequential([
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(10, activation='relu'), input_shape=(i1, i2)),
    tf.keras.layers.Reshape((1,)),
    tf.keras.layers.Dense(1)
])

Error:

ValueError: Exception encountered when calling layer "reshape" (type Reshape).

total size of new array must be unchanged, input_shape = [20], output_shape = [1]

Call arguments received:
  • inputs=tf.Tensor(shape=(None, 20), dtype=float32)
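The error comes from Reshape's size rule: it only rearranges the per-sample elements, so the product of the target dimensions must equal the product of the input dimensions (the batch dimension is excluded on both sides). Bidirectional(LSTM(10)) concatenates both directions into 20 features, and 20 elements cannot fit into a target shape of (1,). A minimal pure-Python sketch of that check; `reshape_keeps_size` is a hypothetical helper for illustration, not a Keras API:

```python
import math


def reshape_keeps_size(input_shape, target_shape):
    """Hypothetical helper: True if reshaping keeps the per-sample element count.

    Both shapes exclude the batch dimension, like Keras' Reshape target_shape.
    """
    return math.prod(input_shape) == math.prod(target_shape)


# Bidirectional(LSTM(10)) concatenates both directions: 2 * 10 = 20 features.
print(reshape_keeps_size((20,), (1,)))     # False, which is why Reshape((1,)) fails
print(reshape_keeps_size((20,), (1, 20)))  # True, so Reshape((1, 20)) is valid
```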

Perhaps something like this (you don't need to account for the batch dimension):

import tensorflow as tf

inputs = tf.keras.layers.Input((2, ))
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
outputs = tf.keras.layers.Reshape((1,) + x.shape[1:])(x)

model = tf.keras.Model(inputs, outputs)
model.summary()

With your code:

import tensorflow as tf

inputs = tf.keras.layers.Input((5, 10))
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(10, activation='relu'))(inputs)
x = tf.keras.layers.Reshape((1,) + x.shape[1:])(x)
outputs = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(inputs, outputs)
model.summary()

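In general, as the documentation shows, if you use -1 for one dimension of the target shape, Reshape infers that dimension from the output shape of the previous layer: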

# also supports shape inference using `-1` as dimension
model.add(tf.keras.layers.Reshape((-1, 2, 2)))
# where 2 and 2 are the new dimensions and -1 is inferred so that the total number of elements in the previous layer's output stays the same.

This works because Reshape internally normalizes the input shape with tf.TensorShape before resolving the -1 entry:

input_shape = tf.TensorShape(input_shape).as_list()
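The -1 entry is resolved from the element count: the known target dimensions are multiplied together, and the -1 dimension gets whatever size is left over. The sketch below mimics that rule in plain Python; `infer_target_shape` is a hypothetical helper written for illustration, not the function Keras uses internally:

```python
import math


def infer_target_shape(input_shape, target_shape):
    """Hypothetical sketch of how a single -1 in a Reshape target is resolved.

    Shapes exclude the batch dimension. Raises ValueError on incompatible sizes.
    """
    total = math.prod(input_shape)
    if list(target_shape).count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    # Product of the dimensions that were given explicitly.
    known = math.prod(d for d in target_shape if d != -1)
    if -1 in target_shape:
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape {input_shape} into {target_shape}")
        # Replace -1 with the leftover size.
        return tuple(total // known if d == -1 else d for d in target_shape)
    if known != total:
        raise ValueError(f"cannot reshape {input_shape} into {target_shape}")
    return tuple(target_shape)


print(infer_target_shape((20,), (1, -1)))    # (1, 20), same as expand_dims(axis=1)
print(infer_target_shape((8,), (-1, 2, 2)))  # (2, 2, 2)
```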

Personally, I prefer specifying the shape explicitly.
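If you prefer explicit shapes, you can write out the feature size instead of -1. A minimal check, assuming TensorFlow 2.x in eager mode and a random tensor standing in for the 20-feature Bidirectional output, that Reshape((1, 20)), Reshape((1, -1)) and tf.expand_dims(x, axis=1) all produce the same tensor:

```python
import numpy as np
import tensorflow as tf

# Stand-in for the Bidirectional(LSTM(10)) output: batch of 4, 20 features each.
x = tf.random.uniform((4, 20))

explicit = tf.keras.layers.Reshape((1, 20))(x)  # target shape written out
inferred = tf.keras.layers.Reshape((1, -1))(x)  # last dimension inferred from -1
expanded = tf.expand_dims(x, axis=1)            # what the original Lambda layer did

for t in (explicit, inferred, expanded):
    print(t.shape)  # (4, 1, 20) for each

# All three hold the same values in the same order.
print(np.array_equal(np.asarray(explicit), np.asarray(expanded)))
print(np.array_equal(np.asarray(inferred), np.asarray(expanded)))
```

Either Reshape form avoids the Lambda layer; the explicit version fails loudly if the feature size ever changes, while -1 adapts to it.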

Instead of tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=1)), we can use tf.keras.layers.Reshape((1, -1)) like this:

import tensorflow as tf

model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(10, activation='relu', input_shape=[100, 256])),
        tf.keras.layers.Reshape((1, -1)),
        tf.keras.layers.Dense(10)
    ])
model(tf.random.uniform((1, 100, 256))) # (batch_dim, input.shape[0], input.shape[1])
model.summary()    

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 bidirectional_1 (Bidirectio  (1, 20)                  21360     
 nal)                                                            
                                                                 
 reshape_1 (Reshape)         (1, 1, 20)                0         
                                                                 
 dense_1 (Dense)             (1, 1, 10)                210       
                                                                 
=================================================================
Total params: 21,570
Trainable params: 21,570
Non-trainable params: 0
_________________________________________________________________
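The parameter counts in this summary can be reproduced by hand, which also confirms that Reshape adds no weights. This sketch assumes the standard Keras LSTM layout of four gates, each with an input kernel, a recurrent kernel and a bias (default use_bias=True); the helper names are hypothetical:

```python
def lstm_params(units, input_dim):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector.
    return 4 * (units * (input_dim + units) + units)


def dense_params(units, input_dim):
    # Kernel plus bias.
    return input_dim * units + units


bidirectional = 2 * lstm_params(units=10, input_dim=256)  # forward + backward LSTM
reshape = 0                                               # Reshape has no weights
dense = dense_params(units=10, input_dim=20)              # 20 = 2 * 10 concatenated features

print(bidirectional, reshape, dense)    # 21360 0 210
print(bidirectional + reshape + dense)  # 21570
```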

Checking against your original Lambda version, we get the same result:

import tensorflow as tf

model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(10, activation='relu', input_shape=[100, 256])),
        tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=1)),
        tf.keras.layers.Dense(10)
    ])
model(tf.random.uniform((1, 100, 256))) # (batch_dim, input.shape[0], input.shape[1])
model.summary()    

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 bidirectional_2 (Bidirectio  (1, 20)                  21360     
 nal)                                                            
                                                                 
 lambda (Lambda)             (1, 1, 20)                0         
                                                                 
 dense_2 (Dense)             (1, 1, 10)                210       
                                                                 
=================================================================
Total params: 21,570
Trainable params: 21,570
Non-trainable params: 0
_________________________________________________________________

