簡體   English   中英

如何在 tensorflow/keras 中使用預定的內核列表初始化 Conv2D 層?

[英]How to initialize a Conv2D layer with predetermined list of kernels in tensorflow/keras?

我想使用Conv2D層來跨越輸入圖像並運行三個2x2內核。

這不是 tensorflow 的目的,但我真的想使用 tensorflow 作為后端引擎來有效地運行內核並在不同設備的 GPU 和/或 CPU 之間分配工作負載。

我嘗試了類似下面的代碼。 但它似乎並沒有很好地工作。

import tensorflow as tf

class InitConvKernels(tf.keras.initializers.Initializer):

  def __init__(self, num_kernels, kernel_tensor):
    self.kernel_list= kernel_tensor
    self.index = -1
    self.num_kernels = num_kernels

  def __call__(self, shape, dtype=None):
    index += 1 
    assert(self.index <= self.num_kernels) # doesn't affect anything
    tf.print(shape) # doesn't work
    return self.kernel_list[index]

  def get_config(self):
    return {'kernel_list': self.kernel_list, 'num_kernels': self.num_kernels}

我正在調用自定義初始化程序,但返回的層是空的:

kernel_list = tf.constant([[[-1, -1],  [-1, -1]], [[1, 1],   [1, 1]],  [[-1, 1],  [1, -1]],])
layer = layers.Conv2D(
    filters=3,
    kernel_size=2,
    kernel_initializer=InitConvKernels(3,kernel_list),
    bias_initializer=initializers.Zeros()
)

layer.variables是空的 ( [] ) layer.layer.get_weights()也是空的 ( [] )

我的目標是評估kernel_list中三個內核在輸入圖像上的卷積並聚合所有結果。

from PIL import Image
import requests
from io import BytesIO
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D


response = requests.get('https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/Stack_Overflow_logo.svg/1280px-Stack_Overflow_logo.svg.png')
image = Image.open(BytesIO(response.content))

從 url 加載圖像 在此處輸入圖像描述

構建 model 以運行kernel (運行更多內核使kernel_init成為生成器,並在初始化Conv2D時輕松調整過濾器的數量)

def kernel_init(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[:,:,0,0] = np.array([[1,0,1],[-1,0,-1],[1,0,1]])
    return kernel

#Build Keras model
model = Sequential()
model.add(Conv2D(1, [3,3], kernel_initializer=kernel_init, 
                 input_shape=(251,1280,4), padding="valid"))
model.build()

# To apply existing filter, we use predict with no training
out = model.predict(image)

並可視化 output:

import matplotlib.pyplot as plt
plt.matshow(out[0,:,:,0])

在此處輸入圖像描述

編輯:值得一提的是 OpenAI 的 Triton ,它可以幫助使用更高級別的語言和框架,例如 pytorch 來運行高效的 GPU 代碼:

類似 Python 的編程語言使沒有 CUDA 經驗的研究人員能夠編寫高效的 GPU 代碼 — 大多數時間與專家能夠產生的代碼相當。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM