
How to Apply a TimeDistributed Dense Layer to a tensor of shape (batch_size, 600, 105, 8) to produce output (batch_size, 600, 48)

I have a Conv2D layer producing a tensor of shape (batch_size, 600, 105, 8). This is a batch of song spectrograms with 8 feature maps. Now I want to apply a "Dense" layer of size 48 to each of the 600 time frames to produce a tensor of shape (batch_size, 600, 48). The default Keras Dense layer doesn't seem to cut it...

Any suggestions?

This is my function:

from keras.models import Sequential
from keras.layers import Conv2D, Dense, Dropout, TimeDistributed

def build_cnn(input_shape=(None, None, 1),
          feature_map_size=8,
          num_layers=5,
          kernel_size=(5, 5),
          dropout=0.2,
          pool_size=(2, 2),
          epochs=100,
          lr=0.001,
          momentum=0.9,
          verbose=False):

    model = Sequential()

    # Add the convolutional layers
    for _ in range(num_layers):
        # Conv layer
        model.add(Conv2D(
            feature_map_size,
            kernel_size,
            input_shape=input_shape,
            padding='same',
            activation='elu')
        )
        # Dropout layer
        # model.add(Dropout(dropout))

    # Dense layer (the part that does not yet give (batch_size, 600, 48))
    model.add(TimeDistributed(Dense(48, activation='elu')))

    return model

The most basic, if crude, way of doing it is to reshape the data:

#after the convolutions:

model.add(Reshape((600,105*8)))
model.add(Dense(48,...))

But this may not be the best choice. Unfortunately I don't understand those spectrograms very well, so I can't suggest other approaches. This one basically throws all bins and features into a single Dense layer, where they will be mixed together.
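To make the shape bookkeeping concrete, here is a minimal NumPy sketch (not Keras code) of what a dense layer does before and after the reshape. The array sizes follow the question. The batch size of 2, the random weights, and the helper name `dense` are assumptions made purely for illustration, and the activation is left out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the Conv2D output: (batch, time frames, frequency bins, feature maps).
x = rng.normal(size=(2, 600, 105, 8))

def dense(inputs, weights, bias):
    """Apply a dense layer along the last axis only (activation omitted)."""
    return inputs @ weights + bias

# Applied directly, a 48-unit dense layer sees only the last axis (8 features),
# so the 105 frequency bins survive as a separate axis.
w_direct = rng.normal(size=(8, 48))
direct = dense(x, w_direct, np.zeros(48))

# Merging bins and features first gives each frame one vector of length 840,
# which a single (840, 48) weight matrix then mixes together.
flat = x.reshape(x.shape[0], 600, 105 * 8)
w_flat = rng.normal(size=(105 * 8, 48))
per_frame = dense(flat, w_flat, np.zeros(48))

print(direct.shape)     # (2, 600, 105, 48)
print(per_frame.shape)  # (2, 600, 48)
```

The second result has the shape asked for in the question, at the cost of treating every (bin, feature) pair as just another input to the same layer.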


In case reshaping with -1 fails

Unknown shapes can be handled inside a Lambda layer using backend functions. This only works with TensorFlow as the backend, though. Theano doesn't support it, which is why I gave up on Theano.

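# Valid with the TensorFlow backend only
import keras.backend as K
from keras.layers import Lambda

def reshape(x):
    # Read the shape at run time, so unknown dimensions are allowed
    shp = K.shape(x)
    # Keep the first two axes and merge the last two: (a, b, c, d) -> (a, b, c*d)
    shp = K.concatenate([shp[:2], shp[2:3] * shp[3:]])
    return K.reshape(x, shp)

model.add(Lambda(reshape))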
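As a rough illustration of what that function computes, here is a NumPy analogue (again not Keras code) that reads the sizes when it is called, so the same function handles clips with different numbers of frames. The helper name `merge_last_two_axes` and the clip lengths are hypothetical, chosen only for this sketch.

```python
import numpy as np

def merge_last_two_axes(x):
    """Collapse axes 2 and 3 of a 4-D array, reading the sizes at call time."""
    batch, frames, bins, features = x.shape
    return x.reshape(batch, frames, bins * features)

# Two inputs with different numbers of time frames.
short_clip = np.zeros((1, 120, 105, 8))
long_clip = np.zeros((1, 600, 105, 8))

print(merge_last_two_axes(short_clip).shape)  # (1, 120, 840)
print(merge_last_two_axes(long_clip).shape)   # (1, 600, 840)
```

Keep in mind that a Keras Dense layer still needs the size of its last input axis to be known when the model is built, so this kind of reshape is most useful when only the batch and time axes are unknown.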
