简体   繁体   English

获得 LSTM 的 top 3 预测,而不是仅获得 top

[英]Get top 3 prediction of LSTM instead of only the top

I have a LSTM model trained on text content.我有一个针对文本内容训练的 LSTM 模型。 And now I want to use that model to generate some sentences.现在我想用那个模型来生成一些句子。 But instead of always picking the best option, i want it to select from for example the top 3, so that it can produce different sentences with the same input, because now I get the same answer for almost every input.但我不是总是选择最佳选项,而是希望它从例如前 3 个选项中进行选择,以便它可以使用相同的输入生成不同的句子,因为现在我几乎每个输入都得到相同的答案。 How do i modify this code so that is possible, I know I need to remove the np.argmax but i don't know how to the return the index of the top 3 highest values.如何修改此代码以便可能,我知道我需要删除np.argmax但我不知道如何返回前 3 个最高值的索引。

Current code:当前代码:

def prediction(seed_text, next_words): 
  for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_seq_length-1, padding='pre')
    predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)
    ouput_word = ""
    for word, index in tokenizer.word_index.items():
      if index == predicted:
        output_word = word
        break
  
    seed_text += ' '+output_word
  return seed_text

np.argsort will give you the indices of the items in an array in the order that sorts them small to large: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html np.argsort将按照从小到大的顺序为您提供数组中项目的索引: https ://numpy.org/doc/stable/reference/generated/numpy.argsort.html

Here's an example using argsort .这是一个使用argsort的示例。 Note that the one with the lowest prediction (index 2, "c" with the predicted value of 0.05) is left out of what is printed.请注意,预测值最低的那个(索引 2,“c”,预测值为 0.05)被排除在打印的内容之外。

import numpy as np

word_index = {'a': 0, 'b': 1, 'c': 2, 'd': 3}

predictions = np.array([0.1, 0.7, 0.05, 0.15])

# add negative to sort large to small; slice to select just up to 3rd index
top_3 = np.argsort(-predictions)[:3]

for word, index in word_index.items():
    if index in top_3:
        print(word)
#> a
#> b
#> d

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM