簡體   English   中英

Tensorflow 在不同名稱范圍內重用變量

[英]Tensorflow reuse variables in different name scope

我遇到了在不同名稱范圍內重用變量的問題。下面的代碼將源嵌入和目標嵌入分開放在兩個不同的空間中，我想要做的是將源和目標放在同一個空間中，重用查找表中的變量。

''' Applying bidirectional encoding for source-side inputs and first-word decoding.
'''
def decode_first_word(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    """Bi-directionally encode the source sentence and decode its first word.

    Looks up source-side embeddings under the 'Source_Side' variable scope,
    runs the bidirectional encoder, then performs the first decoding step.
    Returns the list produced by ``decode_next_word`` with the bi-encoded
    source hidden tensor appended.
    """
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Source_Side'):
            src_emb = self._src_lookup_table(source_vocab_id_tensor)

    with tf.name_scope('Encoding_Layer'):
        enc_hidden = self._encoder.get_biencoded_tensor(
            src_emb, source_mask_tensor)

    with tf.name_scope('Decoding_Layer_First'):
        first_step = self.decode_next_word(
            enc_hidden, source_mask_tensor, None, None, None, scope, reuse)

    return first_step + [enc_hidden]


''' Applying one-step decoding.
'''
def decode_next_word(self, enc_concat_hidden, src_mask, cur_dec_hidden, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    """Apply one decoding step: embed the current target word (if any).

    ``cur_trg_wid`` is None on the very first decoding step (no previous
    target word exists yet); otherwise it is looked up in the target-side
    embedding table under the 'Target_Side' variable scope.
    """
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Target_Side'):
            cur_trg_wemb = None
            # FIX: use `is not None` rather than `None == cur_trg_wid`.
            # `==` on a TF tensor does not test for a missing argument and
            # the inverted if/pass/else was needlessly convoluted.
            if cur_trg_wid is not None:
                cur_trg_wemb = self._trg_lookup_table(cur_trg_wid)

我想按如下方式制作它們,因此整個圖中只有一個嵌入節點:

def decode_first_word_shared_embedding(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    """First-word decoding with a single, shared ('Bi_Side') embedding table.

    Identical pipeline to ``decode_first_word`` except that the source
    tokens are embedded via the shared bi-side lookup table, so source and
    target live in one embedding space. Returns the decode-step results
    with the bi-encoded hidden tensor appended.
    """
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Bi_Side'):
            shared_emb = self._bi_lookup_table(source_vocab_id_tensor)

    with tf.name_scope('Encoding_Layer'):
        enc_hidden = self._encoder.get_biencoded_tensor(
            shared_emb, source_mask_tensor)

    with tf.name_scope('Decoding_Layer_First'):
        first_step = self.decode_next_word_shared_embedding(
            enc_hidden, source_mask_tensor, None, None, None, scope, reuse)

    return first_step + [enc_hidden]

def decode_next_word_shared_embedding(self, enc_concat_hidden, src_mask, cur_dec_hidden, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    """One decoding step using the shared ('Bi_Side') embedding table.

    ``cur_trg_wid`` is None on the first step; otherwise the current target
    word is embedded with the same lookup table used for the source side.
    """
    with tf.name_scope('Word_Embedding_Layer'):
        cur_trg_wemb = None
        # FIX: use `is not None` rather than `None == cur_trg_wid`; `==`
        # against a TF tensor does not test for a missing argument.
        if cur_trg_wid is not None:
            with tf.variable_scope('Bi_Side'):
                cur_trg_wemb = self._bi_lookup_table(cur_trg_wid)

如何實現這一目標?

我通過使用字典來保存嵌入的權重矩陣來解決它。 來自https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/的提示

解決方案之一是保存 variable_scope 實例並重用它。


def decode_first_word_shared_embedding(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    """First-word decoding that shares one 'Bi_Side' embedding variable scope.

    Captures the 'Bi_Side' variable scope after the source lookup so that
    the decoder can re-enter the same scope with reuse=True and share the
    embedding matrix. Returns the decode-step results with the bi-encoded
    source hidden tensor appended.
    """
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Bi_Side'):
            source_embedding_tensor = self._bi_lookup_table(source_vocab_id_tensor)
            # Save the VariableScope object itself (not just its name) so it
            # can be re-opened later regardless of the surrounding scopes.
            shared_variable_scope = tf.get_variable_scope()

    with tf.name_scope('Encoding_Layer'):
        source_concated_hidden_tensor = self._encoder.get_biencoded_tensor(
            source_embedding_tensor, source_mask_tensor)
    with tf.name_scope('Decoding_Layer_First'):
        # BUG FIX: forward the captured scope as the 4th positional argument
        # (shared_variable_scope). Previously it was never passed, shifting
        # every subsequent argument of decode_next_word_shared_embedding.
        rvals = self.decode_next_word_shared_embedding(
            source_concated_hidden_tensor, source_mask_tensor,
            None, shared_variable_scope, None, None, scope, reuse)
    # BUG FIX: removed the stray trailing comma, which silently wrapped the
    # return value in a one-element tuple.
    return rvals + [source_concated_hidden_tensor]

def decode_next_word_shared_embedding(self, enc_concat_hidden, src_mask, cur_dec_hidden, shared_variable_scope, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    """One decoding step that reuses the shared embedding variable scope.

    ``shared_variable_scope`` is the VariableScope object captured around the
    source-side lookup; re-entering it with reuse=True makes this lookup hit
    the same underlying embedding variable instead of creating a new one.
    ``cur_trg_wid`` is None on the first step (no previous target word).
    """
    with tf.variable_scope('Target_Side'):           
        cur_trg_wemb = None 
        if None == cur_trg_wid:
            pass
        else:
            # Opening a VariableScope *object* (rather than a name string)
            # resets to that scope's absolute path, so the variable resolves
            # to Word_Embedding_Layer/Bi_Side/... even from inside
            # 'Target_Side'; reuse=True shares the existing weights.
            with tf.variable_scope(shared_variable_scope, reuse=True):
                cur_trg_wemb = self._bi_lookup_table(cur_trg_wid)

這是我的演示代碼:

# Demo: capture a variable_scope instance and re-enter it elsewhere with
# reuse=True so both get_variable calls resolve to the same variable.
with tf.variable_scope('Word_Embedding_Layer'):
    with tf.variable_scope('Bi_Side'):
        v = tf.get_variable('bi_var', [1], dtype=tf.float32)
        # Save the VariableScope object itself, not just its name string.
        reuse_scope = tf.get_variable_scope()
with tf.variable_scope('Target_side'):
    # some other codes.
    with tf.variable_scope(reuse_scope, reuse=True):
        w = tf.get_variable('bi_var', [1], dtype=tf.float32)
print(v.name)
print(w.name)
# FIX: test object identity explicitly with `is`; `==` is not a reliable
# equality check on TF variables/tensors.
assert v is w

Output:
Word_Embedding_Layer/Bi_Side/bi_var:0
Word_Embedding_Layer/Bi_Side/bi_var:0

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM