簡體   English   中英

我如何解釋這個 TensorFlow tf.nn.conv2d() 層形狀?

[英]How do I explain this TensorFlow tf.nn.conv2d() layer shape?

我的 Tensorflow 卷積層有一個我沒想到的形狀,我沒有看到錯誤。

我是 TensorFlow 的新手,想使用這個函數來創建一個卷積層:

def new_conv_layer(input,              # The previous layer.
                   num_input_channels, # Num. channels in prev. layer.
                   filter_size,        # Width and height of each filter.
                   num_filters,        # Number of filters.
                   use_pooling=True):  # Use 2x2 max-pooling.

    shape = [filter_size, filter_size, num_input_channels, num_filters]

    weights = new_weights(shape=shape)

    biases = new_biases(length=num_filters)

    layer = tf.nn.conv2d(input=input_,
                         filters=weights,
                         strides=[1, 1, 1, 1],
                         padding='SAME')

    layer += biases

    if use_pooling:

        layer = tf.nn.max_pool(input=layer,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME')


    layer = tf.nn.relu(layer)


    return layer, weights

但是當我使用它時

num_channels = 1
img_size = 28

x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])

# Convolutional Layer 1.
filter_size1 = 5          # Convolution filters are 5 x 5 pixels.
num_filters1 = 16         # There are 16 of these filters.

layer_conv1, weights_conv1 = new_conv_layer(input=x_image,
                                           num_input_channels=num_channels,
                                           filter_size=filter_size1,
                                           num_filters=num_filters1,
                                           use_pooling=True)

layer_conv1

我得到這個輸出:

<tf.Tensor 'Relu:0' shape=(None, 392, 392, 16) dtype=float32>

因為我的圖像是方形 28x28 形狀並且我應用了 2x2 池化,所以我希望這個形狀是 (None, 14, 14, 16)。

為什么不是這種情況,我該如何解決?

在我的情況下,這一行x = tf.compat.v1.placeholder(tf.float32, shape=[None, img_size_flat], name='x')不正確!

特別是img_size_flat不是每個“拉伸”圖像的長度,因為它應該是。

img_size_flat = df.drop('label', axis=1).shape[1]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM