[英]Randomly select layer in tensorflow model
我想在我的網絡中使用具有特定概率的不同層。 層是以下類。
class plus1(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, X):
return X + 1
def compute_output_shape(self, batch_input_shape):
return batch_input_shape
class plus2(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, X):
return X + 2
def compute_output_shape(self, batch_input_shape):
return batch_input_shape
class plus3(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, X):
return X + 3
def compute_output_shape(self, batch_input_shape):
return batch_input_shape
網絡如下圖。
def f1():
return plus1()
def f2():
return plus2()
def f3():
return plus3()
def simple_model(input_num):
input_layer = Input(input_num)
rand = tf.random.uniform((1,), minval=0, maxval=3, dtype=tf.int32)
r = tf.switch_case(rand[0], branch_fns={0: f1, 1: f2, 2: f3})
res = r(input_layer)
model = Model(inputs=input_layer, outputs=res)
return model
model = simple_model([1,])
每次我運行下面的代碼時,我都會得到相同的 output,但我預計會有不同的。 有什么方法可以實現嗎?
model.predict([1])
>>> array([[4.]], dtype=float32)
這是我面臨的同樣問題,但我沒有找到解決方案。 所以我實現了不同的網絡,然后從他們的 output 中隨機選擇。
我一直在處理同樣的問題:我有一個層列表,我需要在每次迭代時隨機從其中 select 。 tf.switch_case()
給了我你描述的同樣的問題。
無論出於何種原因,我沒有足夠的背景深度來告訴你為什么(我的tf.switch_case
實現完全有可能以一種不相關的方式出現錯誤),這段代碼對我有用:
def random_layer(layers, image_tensor):
"""
Selects and executes a random layer chosen from a list
"""
to_use = tf.random.uniform(shape=[], maxval=len(layers), dtype=tf.int32)
out = image_tensor
for i, layer in enumerate(layers):
# out is either image_tensor or the actual output, *but*
# since we can't break this loop, when it matches it will become the actual output
# and any further calls will return that value
def _match():
# tf.print("using {}".format(layer))
return layer(out, training=True)
out = tf.cond(to_use==i, _match, lambda: out)
return out
(請注意,我使用的是本地 function 只是為了驗證隨機性。)然后我傳入:
NOISE_LAYERS = [tf.keras.layers.GaussianNoise(stddev=.1),
tf.keras.layers.GaussianNoise(stddev=.2),
tf.keras.layers.GaussianNoise(stddev=.3),
tf.keras.layers.GaussianNoise(stddev=.4)]
(這是數據集准備的一部分,我希望圖像包含不同數量的噪聲。)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.