TensorFlow batch normalization, dimension error

I'm a TensorFlow beginner. I'm trying to use batch normalization with a CNN to push my MNIST accuracy up to 99.5%, but I've run into some problems.

with tf.name_scope('convolution_pooling_1'):
    phase = tf.placeholder(tf.bool, name='phase')
    W_conv1 = tf.Variable(tf.truncated_normal([3,3,1,num_filters1], stddev=0.1), name='conv_1_filter')
    h_conv1 = tf.nn.conv2d(
        x_image, W_conv1, strides=[1,1,1,1], padding='SAME',
        name='filter-output_1')
    bn1 = tf.contrib.layers.batch_norm(h_conv1,
                                       center=True, scale=True,
                                       is_training=phase)

    W_conv2 = tf.Variable(tf.truncated_normal([3,3,1,num_filters1], stddev=0.1), name='conv_2_filter')
    h_conv2 = tf.nn.conv2d(
        bn1, W_conv2, strides=[1,1,1,1], padding='SAME',
        name='filter-output_2')
    b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))
    h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2, name='conv_2_cutoff')
    bn2 = tf.contrib.layers.batch_norm(h_conv2_cutoff,
                                       center=True, scale=True,
                                       is_training=phase)

This is the first part of my CNN. I want to design the model like this: 32-filter convolution - batch normalization - 32-filter convolution - batch normalization.
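For reference, here is a minimal NumPy sketch of what training-mode batch normalization does to a layer in this design. It is not TensorFlow's implementation; it assumes NHWC layout (as `tf.nn.conv2d` uses by default), a fixed epsilon of 1e-3, and a random stand-in for `h_conv1`. Mean and variance are computed per channel over the batch, height and width axes, so the output has exactly the same shape as the input.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode batch normalization for an NHWC tensor (sketch).

    Statistics are taken per channel over axes (batch, height, width),
    so the result has the same shape as x.
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape [1, 1, 1, C]
    var = x.var(axis=(0, 1, 2), keepdims=True)    # shape [1, 1, 1, C]
    x_hat = (x - mean) / np.sqrt(var + eps)       # normalize each channel
    return gamma * x_hat + beta                   # scale (gamma) and center (beta)

rng = np.random.default_rng(0)
# Hypothetical stand-in for the output of a 32-filter convolution on 28x28 MNIST images.
h_conv1 = rng.normal(size=(8, 28, 28, 32))
gamma = np.ones(32)   # learned scale, one value per channel
beta = np.zeros(32)   # learned offset, one value per channel

bn1 = batch_norm_train(h_conv1, gamma, beta)
print(h_conv1.shape, bn1.shape)  # (8, 28, 28, 32) (8, 28, 28, 32)
```

Because the channel count is unchanged, whatever follows the batch normalization layer still sees 32 input channels.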

But during training, I got this error:

ValueError: Dimensions must be equal, but are 32 and 1 for 'convolution_pooling_1/filter-output_2' (op: 'Conv2D') with input shapes: [?,28,28,32], [3,3,1,32].

Does the output of my batch normalization not match what the next convolution expects?

Please help me!

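The problem isn't the batch normalization but the filter shape of your second convolutional layer. tf.nn.conv2d expects a filter of shape [filter_height, filter_width, in_channels, out_channels], and in_channels must equal the number of channels in the layer's input. That input is bn1, which has 32 channels (batch normalization doesn't change the shape of h_conv1), but you defined the filter as [3,3,1,32] (see the error message). It needs to be [3,3,32,32]. Just correct the shape of your W_conv2 variable: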

W_conv2 = tf.Variable(tf.truncated_normal([3,3,num_filters1,num_filters1], stddev=0.1), name='conv_2_filter')
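To see why the original filter fails and the corrected one works, here is a small shape check written in plain Python. `conv2d_output_shape` is a hypothetical helper, not a TensorFlow function; it assumes NHWC input, a [filter_height, filter_width, in_channels, out_channels] filter and 'SAME' padding, and its error text is only modeled on the message above.

```python
def conv2d_output_shape(input_shape, filter_shape, strides=(1, 1)):
    """Output shape of a 'SAME'-padded 2-D convolution (hypothetical helper).

    input_shape:  [batch, height, width, in_channels]  (batch may be None)
    filter_shape: [filter_height, filter_width, in_channels, out_channels]
    """
    batch, height, width, in_channels = input_shape
    _, _, f_in, f_out = filter_shape
    # The filter's third dimension must match the input's channel count.
    if f_in != in_channels:
        raise ValueError(
            f"Dimensions must be equal, but are {in_channels} and {f_in}")
    # With 'SAME' padding the spatial size is ceil(size / stride).
    out_h = -(-height // strides[0])
    out_w = -(-width // strides[1])
    return [batch, out_h, out_w, f_out]

num_filters1 = 32
x_image_shape = [None, 28, 28, 1]  # MNIST images: 28x28, one channel

after_conv1 = conv2d_output_shape(x_image_shape, [3, 3, 1, num_filters1])
after_bn1 = after_conv1  # batch normalization keeps the shape

# Original second filter: in_channels = 1, but the input has 32 channels.
try:
    conv2d_output_shape(after_bn1, [3, 3, 1, num_filters1])
except ValueError as e:
    err = str(e)
print(err)  # Dimensions must be equal, but are 32 and 1

# Corrected second filter: in_channels = num_filters1.
fixed = conv2d_output_shape(after_bn1, [3, 3, num_filters1, num_filters1])
print(fixed)  # [None, 28, 28, 32]
```

The same rule applies to every later convolution: its filter's third dimension must equal the previous layer's number of output channels.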
