簡體   English   中英

如何在 Keras 中實現高斯模糊層?

[英]how do I implement Gaussian blurring layer in Keras?

我有一個自動編碼器,我需要在輸出后添加一個高斯噪聲層。 我需要一個自定義層來執行此操作,但我真的不知道如何生成它,我需要使用張量生成它。 在此處輸入圖片說明

如果我想在下面代碼的調用部分實現上面的等式,我該怎么做?

class SaltAndPepper(Layer):

    def __init__(self, ratio, **kwargs):
        super(SaltAndPepper, self).__init__(**kwargs)
        self.supports_masking = True
        self.ratio = ratio

    # the definition of the call method of custom layer
    def call(self, inputs, training=None):
        def noised():
            shp = K.shape(inputs)[1:]

         **what should I put here????**            
                return out

        return K.in_train_phase(noised(), inputs, training=training)

    def get_config(self):
        config = {'ratio': self.ratio}
        base_config = super(SaltAndPepper, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

我也嘗試使用 lambda 層來實現,但它不起作用。

如果您正在尋找加法乘法高斯噪聲,那么它們已經在 Keras 中實現為一個層: GuassianNoise (加法)和GuassianDropout (乘法)。

但是,如果您在圖像處理中專門尋找高斯模糊過濾器中的模糊效果,那么您可以簡單地使用具有固定權重的深度卷積層(在每個輸入通道上獨立應用過濾器)來獲得所需的輸出(請注意,您需要生成高斯核的權重以將它們設置為 DepthwiseConv2D 層的權重。為此,您可以使用本答案中介紹的函數):

import numpy as np
from keras.layers import DepthwiseConv2D

kernel_size = 3  # set the filter size of Gaussian filter
kernel_weights = ... # compute the weights of the filter with the given size (and additional params)

# assuming that the shape of `kernel_weighs` is `(kernel_size, kernel_size)`
# we need to modify it to make it compatible with the number of input channels
in_channels = 3  # the number of input channels
kernel_weights = np.expand_dims(kernel_weights, axis=-1)
kernel_weights = np.repeat(kernel_weights, in_channels, axis=-1) # apply the same filter on all the input channels
kernel_weights = np.expand_dims(kernel_weights, axis=-1)  # for shape compatibility reasons

# define your model...

# somewhere in your model you want to apply the Gaussian blur,
# so define a DepthwiseConv2D layer and set its weights to kernel weights
g_layer = DepthwiseConv2D(kernel_size, use_bias=False, padding='same')
g_layer_out = g_layer(the_input_tensor_for_this_layer)  # apply it on the input Tensor of this layer

# the rest of the model definition...

# do this BEFORE calling `compile` method of the model
g_layer.set_weights([kernel_weights])
g_layer.trainable = False  # the weights should not change during training

# compile the model and start training...

經過一段時間試圖弄清楚如何使用@today 提供的代碼執行此操作后,我決定與將來可能需要它的任何人共享我的最終代碼。 我創建了一個非常簡單的模型,它只對輸入數據應用模糊處理:

import numpy as np
from keras.layers import DepthwiseConv2D
from keras.layers import Input
from keras.models import Model


def gauss2D(shape=(3,3),sigma=0.5):

    m,n = [(ss-1.)/2. for ss in shape]
    y,x = np.ogrid[-m:m+1,-n:n+1]
    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
    h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return h

def gaussFilter():
    kernel_size = 3
    kernel_weights = gauss2D(shape=(kernel_size,kernel_size))
    
    
    in_channels = 1  # the number of input channels
    kernel_weights = np.expand_dims(kernel_weights, axis=-1)
    kernel_weights = np.repeat(kernel_weights, in_channels, axis=-1) # apply the same filter on all the input channels
    kernel_weights = np.expand_dims(kernel_weights, axis=-1)  # for shape compatibility reasons
    
    
    inp = Input(shape=(3,3,1))
    g_layer = DepthwiseConv2D(kernel_size, use_bias=False, padding='same')(inp)
    model_network = Model(input=inp, output=g_layer)
    model_network.layers[1].set_weights([kernel_weights])
    model_network.trainable= False #can be applied to a given layer only as well
        
    return model_network

a = np.array([[[1, 2, 3], [4, 5, 6], [4, 5, 6]]])
filt = gaussFilter()
print(a.reshape((1,3,3,1)))
print(filt.predict(a.reshape(1,3,3,1)))

出於測試目的,數據只有1,3,3,1的形狀,函數gaussFilter()創建了一個非常簡單的模型,只有輸入和一個卷積層,提供高斯模糊,權gauss2D()函數gauss2D()定義。 您可以向函數添加參數以使其更具動態性,例如形狀、內核大小、通道。 只有在將層添加到模型后才能應用根據我的發現的權重。

由於 Error: AttributeError: 'float' object has no attribute 'dtype' K.sqrt AttributeError: 'float' object has no attribute 'dtype' ,只需將K.sqrt更改為math.sqrt ,它就會起作用。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM