繁体   English   中英

获得 LSTM 的 top 3 预测,而不是仅获得 top

[英]Get top 3 prediction of LSTM instead of only the top

我有一个针对文本内容训练的 LSTM 模型。 现在我想用那个模型来生成一些句子。 但我不是总是选择最佳选项,而是希望它从例如前 3 个选项中进行选择,以便它可以使用相同的输入生成不同的句子,因为现在我几乎每个输入都得到相同的答案。 如何修改此代码以便可能,我知道我需要删除np.argmax但我不知道如何返回前 3 个最高值的索引。

当前代码:

def prediction(seed_text, next_words): 
  for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_seq_length-1, padding='pre')
    predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)
    ouput_word = ""
    for word, index in tokenizer.word_index.items():
      if index == predicted:
        output_word = word
        break
  
    seed_text += ' '+output_word
  return seed_text

np.argsort将按照从小到大的顺序为您提供数组中项目的索引: https ://numpy.org/doc/stable/reference/generated/numpy.argsort.html

这是一个使用argsort的示例。 请注意,预测值最低的那个(索引 2,“c”,预测值为 0.05)被排除在打印的内容之外。

import numpy as np

word_index = {'a': 0, 'b': 1, 'c': 2, 'd': 3}

predictions = np.array([0.1, 0.7, 0.05, 0.15])

# add negative to sort large to small; slice to select just up to 3rd index
top_3 = np.argsort(-predictions)[:3]

for word, index in word_index.items():
    if index in top_3:
        print(word)
#> a
#> b
#> d

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM