簡體   English   中英

Tensorflow實現word2vec

[英]Tensorflow implementation of word2vec

該Tensorflow教程這里指的是他們的基本實現,你可以在github上找到這里 ,這里的Tensorflow作者實現word2vec矢量嵌入培訓/評估與Skipgram模型。

我的問題是關於generate_batch()函數中(target,context)對的實際生成。

這一行上, Tensorflow作者從單詞滑動窗口中的“中心”單詞索引中隨機抽取附近的目標索引。

但是,它們還保留了一個數據結構targets_to_avoid ,它們首先添加“中心”上下文單詞(當然我們不想采樣),但在添加它們之后還要添加其他單詞。

我的問題如下:

  1. 為什么不從這個滑動窗口圍繞這個詞進行采樣,為什么不只是有一個循環並使用它們而不是采樣? 他們會擔心word2vec_basic.py (他們的“基本”實現)中的性能/內存似乎很奇怪。
  2. 無論1)的答案是什么,為什么他們采樣並跟蹤他們用targets_to_avoid選擇的targets_to_avoid 如果他們想要真正隨機,他們會使用替換選擇,如果他們想確保他們獲得所有選項,他們應該只使用一個循環並首先獲得所有選項!
  3. 內置的tf.models.embedding.gen_word2vec也是這樣工作的嗎? 如果是這樣我在哪里可以找到源代碼? (在Github repo中找不到.py文件)

謝謝!

我嘗試了你提出的生成批次的方法 - 有一個循環並使用整個跳過窗口。 結果是:

1.更快地生成批次

批量大小為128,跳過窗口為5

  • 通過逐個循環數據生成批次每10,000批次需要0.73秒
  • 使用教程代碼生成批次並且num_skips=2每10,000批次需要3.59s

2.過度擬合的危險性更高

保持教程代碼的其余部分,我用兩種方式訓練模型並記錄每2000步的平均損失:

在此輸入圖像描述

這種模式反復發生。 它表明,每個單詞使用10個樣本而不是2個樣本會導致過度擬合。

這是我用於生成批次的代碼。 它取代了教程的generate_batch函數。

data_index = 0

def generate_batch(batch_size, skip_window):
    global data_index
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)  # Row
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  # Column

    # For each word in the data, add the context to the batch and the word to the labels
    batch_index = 0
    while batch_index < batch_size:
        context = data[get_context_indices(data_index, skip_window)]

        # Add the context to the remaining batch space
        remaining_space = min(batch_size - batch_index, len(context))
        batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
        labels[batch_index:batch_index + remaining_space] = data[data_index]

        # Update the data_index and the batch_index
        batch_index += remaining_space
        data_index = (data_index + 1) % len(data)

    return batch, labels

編輯: get_context_indices是一個簡單的函數,它返回data_index周圍的skip_window中的索引切片。 有關詳細信息,請參閱slice()文檔

有一個名為num_skips的參數,表示從單個窗口生成的(輸入,輸出)對的數量:[skip_window target skip_window]。 所以num_skips限制了我們用作輸出詞的上下文詞的數量。 這就是為什么generate_batch函數assert num_skips <= 2*skip_window 代碼只是隨機選取num_skip上下文單詞來構建帶目標的訓練對。 但我不知道num_skips如何影響性能。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM