繁体   English   中英

与 tensorflow 2 的深度互相关

[英]Depthwise cross-correlation with tensorflow 2

我想用 tensorflow 2 和 keras 实现SiamRPN++中描述的深度互相关层。 它应该是 keras 层的子类,以允许灵活使用。 我的实现编译正确,但在训练 tensorflow 时抛出错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:从形状为 [24,8,32] 的张量中指定形状为 [8,24,32] 的列表

这是我的代码。 我究竟做错了什么?

class CrossCorr(Layer):
"""
Implements the cross correlation laer of siam_rpn_plus_plus
"""

def __init__(self, **kwargs):
    super().__init__(**kwargs)

def build(self, inputs):
    super(CrossCorr, self).build(inputs)  # Be sure to call this at the end

def call(self, inputs, **kwargs):
    def _corr(search_img, filter):
        x = tf.expand_dims(search_img, 0)
        f = tf.expand_dims(filter, -1)
        # use the feature map as kernel for the depthwise conv2d of tensorflow
        return tf.nn.depthwise_conv2d(input=x, filter=f, strides=[1, 1, 1, 1], padding='SAME')
    
    # Iteration over each batch
    out = tf.map_fn(
        lambda filter_simg: _corr(filter_simg[0], filter_simg[1]),
        elems=inputs,
        dtype=inputs[0].dtype
    )
    return tf.squeeze(out, [1])

def compute_output_shape(self, input_shape):
    return input_shape

要调用它,我使用:

def _conv_block(inputs, filters, kernel, strides, kernel_regularizer=None):
    x = Conv2D(filters, kernel, padding='same', strides=strides, 
    kernel_regularizer=kernel_regularizer)(inputs)
    x = BatchNormalization()(x)
    return Activation(relu)(x)


def cross_correlation_layer(search_img, template_img, n_filters=None):
    n_filters = int(search_img.shape[-1]) if n_filters is None else n_filters

    tmpl = _conv_block(template_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))
    search = _conv_block(search_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))

    # calculate cross correlation by striding the generated "filter" over the image in depthwise manner
    cc = CrossCorr()([search, tmpl])
    # 1D conv to make it a seperable convolution
    cc = Conv2D(filters=n_filters, kernel_size=1, strides=1)(cc)
    # apply one more filter over it
    fusion = _conv_block(cc, n_filters, 3, 1)
    return fusion

在尝试运行您的代码后,我意识到该层(您通过 _conv_block 调用)需要一组批处理图像,以下是您的 cross_correlation_layer function 中的上述修改

def cross_correlation_layer(search_img, template_img, n_filters=None):
    n_filters = int(search_img.shape[-1]) if n_filters is None else n_filters

    template_img = tf.expand_dims(template_img, -1)
    search_img = tf.expand_dims(search_img, -1)

    tmpl = _conv_block(template_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))
    search = _conv_block(search_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))

您的 function 的 rest 保持不变,希望这对您有所帮助,(与您的问题无关,但“过滤器”一词是 Python 中的保留关键字,我们尽量避免将其用作参数名称)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM