簡體   English   中英

如何在自定義 keras 層中有未跟蹤的權重?

[英]How to have un-tracked weights in custom keras layer?

我想創建一個自定義 keras 層(VQVAE model 的代碼簿。)在訓練時,我想要一個tf.Variable來跟蹤每個代碼的使用情況,以便我可以重新啟動未使用的代碼。 所以我創建了我的 Codebook 層如下......

class Codebook(layers.Layer): 
     def __init__(self, num_codes, code_reset_limit = None, **kwargs): 
         super().__init__(**kwargs) 
         self.num_codes = num_codes 
         self.code_reset_limit = code_reset_limit 
         if self.code_reset_limit: 
             self.code_counter = tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False) 
     def build(self, input_shape): 
         self.codes = self.add_weight(name = 'codes',  
                                      shape = (self.num_codes, input_shape[-1]), 
                                      initializer = 'random_uniform',  
                                      trainable = True) 
         super().build(input_shape) 
                                                                                                             

我遇到的問題是Layer class 找到成員變量self.code_counter並將其添加到與層一起保存的權重列表中。 它還期望在加載權重時出現self.code_counter ,而當我在推理模式下運行時,情況並非如此。 我怎樣才能使 keras 不跟蹤我層中的變量。 我不希望它持續存在或成為layers.weights的一部分。

根據文檔

設置為圖層屬性的變量作為圖層的權重進行跟蹤(在 layer.weights 中)

所以問題是您是否可以單獨使用tf.zeros或與tf.constant一起使用:

import tensorflow as tf

class Codebook(tf.keras.layers.Layer): 
     def __init__(self, num_codes, code_reset_limit = None, **kwargs): 
         super().__init__(**kwargs) 
         self.num_codes = num_codes 
         self.code_reset_limit = code_reset_limit 
         if self.code_reset_limit: 
            self.code_counter = tf.constant(tf.zeros(num_codes, dtype = tf.int32))

     def build(self, input_shape): 
         self.codes = self.add_weight(name = 'codes',  
                                      shape = (self.num_codes, input_shape[-1]), 
                                      initializer = 'random_uniform',  
                                      trainable = True) 
         super().build(input_shape) 
code_book = Codebook(num_codes=5, code_reset_limit=True)
print(code_book.weights)
[]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM