Get batch size in Keras custom layer and use tensorflow operations (tf.Variable)
I would like to write a Keras custom layer with TensorFlow operations that require the batch size as input. Apparently I'm struggling in every nook and cranny.
Suppose a very simple layer that (1) gets the batch size, (2) creates a tf.Variable (let's call it my_var) based on the batch size and then applies some tf.random ops to alter my_var, and (3) finally returns the input multiplied by my_var.
What I tried so far:
class TestLayer(Layer):
    def __init__(self, **kwargs):
        self.num_batch = None
        self.my_var = None
        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        var_init = tf.ones(self.batch_size, dtype=tf.float32)
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=False)
        # some tensorflow random operations to alter self.my_var
        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return self.my_var * x

    def compute_output_shape(self, input_shape):
        return input_shape
Now creating a very simple model:
# define model
input_layer = Input(shape=(2, 2, 3), name='input_layer')
x = TestLayer()(input_layer)
# connect model
my_mod = Model(inputs=input_layer, outputs=x)
my_mod.summary()
Unfortunately, whatever I try or change in the code, I get multiple errors, most of them with very cryptic tracebacks (ValueError: Cannot convert a partially known TensorShape to a Tensor, or ValueError: None values not supported.).
Any general suggestions? Thanks in advance.
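Both errors most likely trace back to the same cause: when Input is created without a batch_size, the batch dimension is None while build() runs, so input_shape[0] is None and cannot be used as a concrete size for tf.ones or tf.Variable. A quick sketch to check this (assuming TensorFlow 2.x with tf.keras):

```python
import tensorflow as tf

# No batch_size given: the batch dimension is unknown (None) at build time.
dynamic_input = tf.keras.Input(shape=(2, 2, 3))
print(dynamic_input.shape)  # (None, 2, 2, 3)

# batch_size given: the batch dimension is a concrete integer.
fixed_input = tf.keras.Input(shape=(2, 2, 3), batch_size=10)
print(fixed_input.shape)  # (10, 2, 2, 3)
```

Whatever value appears in that first position is what build() receives as input_shape[0].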
You need to specify the batch size (for example by passing batch_size=10 to Input) if you want to create a variable of size batch_size. Additionally, if you want to print a summary, the tf.Variable must have a fixed shape (validate_shape=True), and it must be broadcastable so that it can be multiplied by the input:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.models import Model


class TestLayer(Layer):
    def __init__(self, **kwargs):
        self.num_batch = None
        self.my_var = None
        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        var_init = tf.ones(self.batch_size, dtype=tf.float32)[..., None, None, None]
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=True)
        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        res = self.my_var * x
        return res

    def compute_output_shape(self, input_shape):
        return input_shape

# define model
input_layer = Input(shape=(2, 2, 3), name='input_layer', batch_size=10)
x = TestLayer()(input_layer)
# connect model
my_mod = Model(inputs=input_layer, outputs=x)
my_mod.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_layer (InputLayer)     (10, 2, 2, 3)             0
_________________________________________________________________
test_layer (TestLayer)       (10, 2, 2, 3)             0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
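If the per-sample factor does not have to persist as a tf.Variable between calls, the batch size does not need to be fixed at all: it can be read inside call() with tf.shape(x)[0], which is a tensor evaluated at run time rather than a static Python value. The sketch below assumes TensorFlow 2.x; the layer name RandomScaleLayer and the uniform range 0.5 to 1.5 are hypothetical choices for illustration, not part of the answer above. The random factor is given shape (batch_size, 1, 1, 1) for the same broadcasting reason as my_var.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.models import Model


class RandomScaleLayer(Layer):
    """Hypothetical layer: scales each sample by its own random factor.

    The batch size is read at call time, so it may be None when the model is built.
    """

    def call(self, x):
        # Dynamic scalar tensor holding the actual batch size of this call.
        batch_size = tf.shape(x)[0]
        # Shape (batch_size, 1, 1, 1) broadcasts against (batch_size, 2, 2, 3).
        factors = tf.random.uniform(
            tf.stack([batch_size, 1, 1, 1]), minval=0.5, maxval=1.5, dtype=x.dtype
        )
        return factors * x

    def compute_output_shape(self, input_shape):
        return input_shape


input_layer = Input(shape=(2, 2, 3), name="input_layer")  # batch size left as None
outputs = RandomScaleLayer()(input_layer)
model = Model(inputs=input_layer, outputs=outputs)
model.summary()

# Any batch size works, because nothing depends on a fixed batch dimension.
result = model(tf.ones((4, 2, 2, 3)))
print(result.shape)  # (4, 2, 2, 3)
```

The trade-off is that the factors are regenerated on every call and cannot be read back or updated like my_var; if persistent per-sample state is required, the fixed batch_size approach remains necessary.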