[英]How to have un-tracked weights in custom keras layer?
我想創建一個自定義 keras 層(VQVAE model 的代碼簿。)在訓練時,我想要一個tf.Variable
來跟蹤每個代碼的使用情況,以便我可以重新啟動未使用的代碼。 所以我創建了我的 Codebook 層如下......
class Codebook(layers.Layer):
def __init__(self, num_codes, code_reset_limit = None, **kwargs):
super().__init__(**kwargs)
self.num_codes = num_codes
self.code_reset_limit = code_reset_limit
if self.code_reset_limit:
self.code_counter = tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False)
def build(self, input_shape):
self.codes = self.add_weight(name = 'codes',
shape = (self.num_codes, input_shape[-1]),
initializer = 'random_uniform',
trainable = True)
super().build(input_shape)
我遇到的問題是Layer
class 找到成員變量self.code_counter
並將其添加到與層一起保存的權重列表中。 它還期望在加載權重時出現self.code_counter
,而當我在推理模式下運行時,情況並非如此。 我怎樣才能使 keras 不跟蹤我層中的變量。 我不希望它持續存在或成為layers.weights
的一部分。
根據文檔:
設置為圖層屬性的變量作為圖層的權重進行跟蹤(在 layer.weights 中)
所以問題是您是否可以單獨使用tf.zeros
或與tf.constant
一起使用:
import tensorflow as tf
class Codebook(tf.keras.layers.Layer):
def __init__(self, num_codes, code_reset_limit = None, **kwargs):
super().__init__(**kwargs)
self.num_codes = num_codes
self.code_reset_limit = code_reset_limit
if self.code_reset_limit:
self.code_counter = tf.constant(tf.zeros(num_codes, dtype = tf.int32))
def build(self, input_shape):
self.codes = self.add_weight(name = 'codes',
shape = (self.num_codes, input_shape[-1]),
initializer = 'random_uniform',
trainable = True)
super().build(input_shape)
code_book = Codebook(num_codes=5, code_reset_limit=True)
print(code_book.weights)
[]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.