
How to (efficiently) apply a channel-wise fully connected layer in TensorFlow

I'm coming to you again, scratching my head at something that I can get to work, but only really slowly. I hope you can help me optimize it.

I'm trying to implement a convolutional auto-encoder in TensorFlow with a big latent space between the encoder and decoder. Usually, one would connect the encoder to the decoder with a fully connected layer, but because this latent space is high-dimensional, doing so would create far too many parameters to be computationally feasible.

I found a nice solution to this problem in this paper. They call it a 'channel-wise fully connected layer', which is basically a separate fully connected layer per channel.
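To make the shapes concrete: if the encoder output is viewed as [batch, features, channels], the layer applies its own [features, out_features] weight matrix to each channel. Here is a minimal shape-level sketch of that idea (the variable names and the einsum formulation are mine, not from the paper):

import tensorflow as tf

# Hypothetical sizes for illustration: each of the 2048 channels gets its own
# small (4 x 16) weight matrix, applied in a single einsum instead of a loop.
batch, in_features, out_features, channels = 10, 4, 16, 2048
x = tf.random_normal([batch, in_features, channels])
W = tf.get_variable("W_cwfc", shape=[channels, in_features, out_features],
                    initializer=tf.random_normal_initializer(0., 0.005))

# output[b, o, c] = sum_i x[b, i, c] * W[c, i, o]
y = tf.einsum('bic,cio->boc', x, W)  # shape: [batch, out_features, channels]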

I'm working on the implementation and I got it to work, but building the graph takes a very long time. This is my code so far:

def _network(self, dataset, isTraining):
    encoded = self._encoder(dataset, isTraining)
    with tf.variable_scope("fully_connected_channel_wise"):
        shape = encoded.get_shape().as_list()
        print(shape)
        channel_wise = tf.TensorArray(dtype=tf.float32, size=shape[-1])
        # one small linear layer per output channel of the last conv layer
        for i in range(shape[-1]):
            channel_wise = channel_wise.write(
                i, self._linearLayer(encoded[:, :, i], shape[1], shape[1] * 4,
                                     name='Channel-wise' + str(i),
                                     isTraining=isTraining))
        channel_wise = channel_wise.concat()
        reshape = tf.reshape(channel_wise, [shape[0], shape[1] * 4, shape[-1]])
    reconstructed = self._decoder(reshape, isTraining)
    return reconstructed

So, any ideas as to why this is taking so long? That loop is over range(2048) in practice, but all of the linear layers are really small (4x16). Am I approaching this the wrong way?

Thanks!

You can check the authors' TensorFlow implementation of that paper. Here is their implementation of the 'channel-wise fully connected layer':

def channel_wise_fc_layer(self, input, name):  # bottom: (7x7x512)
    _, width, height, n_feat_map = input.get_shape().as_list()
    input_reshape = tf.reshape(input, [-1, width * height, n_feat_map])
    input_transpose = tf.transpose(input_reshape, [2, 0, 1])  # (512, batch, 49)

    with tf.variable_scope(name):
        W = tf.get_variable(
            "W",
            shape=[n_feat_map, width * height, width * height],  # (512, 49, 49)
            initializer=tf.random_normal_initializer(0., 0.005))
        output = tf.batch_matmul(input_transpose, W)

    output_transpose = tf.transpose(output, [1, 2, 0])  # (batch, 49, 512)
    output_reshape = tf.reshape(output_transpose, [-1, height, width, n_feat_map])

    return output_reshape

https://github.com/jazzsaxmafia/Inpainting/blob/8c7735ec85393e0a1d40f05c11fa1686f9bd530f/src/model.py#L60

The main idea is to use the tf.batch_matmul function: it multiplies each channel's (49 x 49) weight matrix with that channel's activations in a single batched op, instead of building a separate layer per channel in a Python loop.

However, tf.batch_matmul has been removed in newer versions of TensorFlow; you can use tf.matmul instead, which now supports batched 3-D inputs directly.
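For example, here is a minimal, untested sketch of the same layer with tf.matmul substituted for tf.batch_matmul (everything else follows the linked implementation; the shape comments assume a (batch, 7, 7, 512) input):

import tensorflow as tf

def channel_wise_fc_layer(input, name):
    # input: (batch, height, width, n_feat_map), e.g. (batch, 7, 7, 512)
    _, height, width, n_feat_map = input.get_shape().as_list()
    input_reshape = tf.reshape(input, [-1, width * height, n_feat_map])
    input_transpose = tf.transpose(input_reshape, [2, 0, 1])  # (512, batch, 49)

    with tf.variable_scope(name):
        W = tf.get_variable(
            "W",
            shape=[n_feat_map, width * height, width * height],  # (512, 49, 49)
            initializer=tf.random_normal_initializer(0., 0.005))
        # tf.matmul treats the leading dimension of 3-D tensors as a batch
        # dimension, so this is one (49 x 49) matmul per channel in a single op
        output = tf.matmul(input_transpose, W)  # (512, batch, 49)

    output_transpose = tf.transpose(output, [1, 2, 0])  # (batch, 49, 512)
    output_reshape = tf.reshape(output_transpose, [-1, height, width, n_feat_map])
    return output_reshape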
