[英]Tensorflow implementation of word2vec
該Tensorflow教程這里指的是他們的基本實現,你可以在github上找到這里 ,這里的Tensorflow作者實現word2vec矢量嵌入培訓/評估與Skipgram模型。
我的問題是關於generate_batch()
函數中(target,context)對的實際生成。
在這一行上, Tensorflow作者從單詞滑動窗口中的“中心”單詞索引中隨機抽取附近的目標索引。
但是,它們還保留了一個數據結構targets_to_avoid
,它們首先添加“中心”上下文單詞(當然我們不想采樣),但在添加它們之后還要添加其他單詞。
我的問題如下:
word2vec_basic.py
(他們的“基本”實現)中的性能/內存似乎很奇怪。 targets_to_avoid
選擇的targets_to_avoid
? 如果他們想要真正隨機,他們會使用替換選擇,如果他們想確保他們獲得所有選項,他們應該只使用一個循環並首先獲得所有選項! 謝謝!
我嘗試了你提出的生成批次的方法 - 有一個循環並使用整個跳過窗口。 結果是:
1.更快地生成批次
批量大小為128,跳過窗口為5
num_skips=2
每10,000批次需要3.59s 2.過度擬合的危險性更高
保持教程代碼的其余部分,我用兩種方式訓練模型並記錄每2000步的平均損失:
這種模式反復發生。 它表明,每個單詞使用10個樣本而不是2個樣本會導致過度擬合。
這是我用於生成批次的代碼。 它取代了教程的generate_batch
函數。
data_index = 0
def generate_batch(batch_size, skip_window):
global data_index
batch = np.ndarray(shape=(batch_size), dtype=np.int32) # Row
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) # Column
# For each word in the data, add the context to the batch and the word to the labels
batch_index = 0
while batch_index < batch_size:
context = data[get_context_indices(data_index, skip_window)]
# Add the context to the remaining batch space
remaining_space = min(batch_size - batch_index, len(context))
batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
labels[batch_index:batch_index + remaining_space] = data[data_index]
# Update the data_index and the batch_index
batch_index += remaining_space
data_index = (data_index + 1) % len(data)
return batch, labels
編輯: get_context_indices
是一個簡單的函數,它返回data_index周圍的skip_window中的索引切片。 有關詳細信息,請參閱slice()文檔 。
有一個名為num_skips
的參數,表示從單個窗口生成的(輸入,輸出)對的數量:[skip_window target skip_window]。 所以num_skips
限制了我們用作輸出詞的上下文詞的數量。 這就是為什么generate_batch函數assert num_skips <= 2*skip_window
。 代碼只是隨機選取num_skip
上下文單詞來構建帶目標的訓練對。 但我不知道num_skips
如何影響性能。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.