簡體   English   中英

理解 tf.nn.depthwise_conv2d

[英]Understanding tf.nn.depthwise_conv2d

來自https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d

給定一個 4D 輸入張量('NHWC' 或 'NCHW' 數據格式)和一個形狀為 [filter_height, filter_width, in_channels, channel_multiplier] 的濾波器張量,包含深度為 1 的 in_channels 卷積濾波器,depthwise_conv2d 對每個輸入通道應用不同的濾波器(擴展從 1 個通道到每個通道的 channel_multiplier 通道),然后將結果連接在一起。 輸出有 in_channels * channel_multiplier 通道

  1. “從 1 個通道擴展到每個通道的 channel_multiplier 通道”是什么意思?
  2. 是否有可能有 out_channels < in_channels?
  3. 是否可以像在 Pytorch https://pytorch.org/docs/stable/nn.html#conv2d 中那樣將輸入張量划分為組?

例子:

import tensorflow as tf
import numpy as np
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

np.random.seed(2020)

print('tf.__version__', tf.__version__)

def get_data_batch():
    bs = 2
    h = 3
    w = 3
    c = 4

    x_np = np.random.rand(bs, h, w, c)
    x_np = x_np.astype(np.float32)
    print('x_np.shape', x_np.shape)

    return x_np


def run_conv_dw():
    print('='*60)
    x_np = get_data_batch()
    in_channels = x_np.shape[-1]
    kernel_size = 3
    channel_multiplier = 1
    with tf.Session() as sess:
        x_tf = tf.convert_to_tensor(x_np)
        filter = tf.get_variable('w1', [kernel_size, kernel_size, in_channels, channel_multiplier],
                                 initializer=tf.contrib.layers.xavier_initializer())
        z_tf = tf.nn.depthwise_conv2d(x_tf, filter=filter, strides=[1, 1, 1, 1], padding='SAME')

        sess.run(tf.global_variables_initializer())
        z_np = sess.run(fetches=[z_tf], feed_dict={x_tf: x_np})[0]
        print('z_np.shape', z_np.shape)


if '__main__' == __name__:
    run_conv_dw()

通道乘數不能浮動:

如果channel_multiplier = 1

x_np.shape (2, 3, 3, 4)
z_np.shape (2, 3, 3, 4)

如果channel_multiplier = 2

x_np.shape (2, 3, 3, 4)
z_np.shape (2, 3, 3, 8)

在pytorch術語中:

  1. 每組總是一個輸入通道,每組'channel_multiplier'輸出通道;
  2. 不是一步;
  3. 見 1

我看到了一種模擬每組多個輸入通道的方法。 對於兩個,執行depthwise_conv2d ,然后將結果張量作為一副牌分成兩半,然后按元素求和獲得的一半(在 relu 等之前)。 請注意,輸入通道編號i將與i+inputs/2組合在一起。


編輯:上面的技巧對於小團體很有用,對於大團體來說,只需將輸入張量拆分為 N 個部分,其中 N 是組數,每個獨立地制作conv2d ,然后連接結果。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM