[英]How to initialize a Conv2D layer with predetermined list of kernels in tensorflow/keras?
我想使用Conv2D
層來跨越輸入圖像並運行三個2x2
內核。
這不是 tensorflow 的目的,但我真的想使用 tensorflow 作為后端引擎來有效地運行內核並在不同設備的 GPU 和/或 CPU 之間分配工作負載。
我嘗試了類似下面的代碼。 但它似乎並沒有很好地工作。
import tensorflow as tf
class InitConvKernels(tf.keras.initializers.Initializer):
def __init__(self, num_kernels, kernel_tensor):
self.kernel_list= kernel_tensor
self.index = -1
self.num_kernels = num_kernels
def __call__(self, shape, dtype=None):
index += 1
assert(self.index <= self.num_kernels) # doesn't affect anything
tf.print(shape) # doesn't work
return self.kernel_list[index]
def get_config(self):
return {'kernel_list': self.kernel_list, 'num_kernels': self.num_kernels}
我正在調用自定義初始化程序,但返回的層是空的:
kernel_list = tf.constant([[[-1, -1], [-1, -1]], [[1, 1], [1, 1]], [[-1, 1], [1, -1]],])
layer = layers.Conv2D(
filters=3,
kernel_size=2,
kernel_initializer=InitConvKernels(3,kernel_list),
bias_initializer=initializers.Zeros()
)
layer.variables
是空的 ( []
) layer.layer.get_weights()
也是空的 ( []
)
我的目標是評估kernel_list
中三個內核在輸入圖像上的卷積並聚合所有結果。
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D
response = requests.get('https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/Stack_Overflow_logo.svg/1280px-Stack_Overflow_logo.svg.png')
image = Image.open(BytesIO(response.content))
構建 model 以運行kernel (運行更多內核使kernel_init
成為生成器,並在初始化Conv2D
時輕松調整過濾器的數量)
def kernel_init(shape, dtype=None, partition_info=None):
kernel = np.zeros(shape)
kernel[:,:,0,0] = np.array([[1,0,1],[-1,0,-1],[1,0,1]])
return kernel
#Build Keras model
model = Sequential()
model.add(Conv2D(1, [3,3], kernel_initializer=kernel_init,
input_shape=(251,1280,4), padding="valid"))
model.build()
# To apply existing filter, we use predict with no training
out = model.predict(image)
並可視化 output:
import matplotlib.pyplot as plt
plt.matshow(out[0,:,:,0])
編輯:值得一提的是 OpenAI 的 Triton ,它可以幫助使用更高級別的語言和框架,例如 pytorch 來運行高效的 GPU 代碼:
類似 Python 的編程語言使沒有 CUDA 經驗的研究人員能夠編寫高效的 GPU 代碼 — 大多數時間與專家能夠產生的代碼相當。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.