
Why does my keras LSTM model get stuck in an infinite loop?

I am trying to build a small LSTM that can learn to write code (even if it's garbage code) by training it on existing Python code. I have concatenated a few thousand lines of code from several hundred files into one training file, with each file ending in <eos> to signify "end of sequence".

As an example, my training file looks like:


setup(name='Keras',
...
      ],
      packages=find_packages())
<eos>
import pyux
...
with open('api.json', 'w') as f:
    json.dump(sign, f)
<eos>

I am creating tokens from the words with:

file = open(self.textfile, 'r')
filecontents = file.read()
file.close()

# collapse blank lines, then treat newlines and indentation as their own tokens
filecontents = filecontents.replace("\n\n", "\n")
filecontents = filecontents.replace('\n', ' \n ')
filecontents = filecontents.replace('    ', ' \t ')

text_in_words = [w for w in filecontents.split(' ') if w != '']

self._words = set(text_in_words)

# slide a window of seq_length words over the text; the word after
# each window becomes the prediction target
STEP = 1
self._codelines = []
self._next_words = []
for i in range(0, len(text_in_words) - self.seq_length, STEP):
    self._codelines.append(text_in_words[i: i + self.seq_length])
    self._next_words.append(text_in_words[i + self.seq_length])
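
Since the model below is compiled with sparse_categorical_crossentropy and starts with an Embedding layer, both the input windows and the target words have to be integer-encoded before training. A minimal sketch of that step, assuming a word_indices lookup built from self._words (the variable names here are illustrative, not from the original code):

import numpy as np

# hypothetical lookup from word to integer index
word_indices = {w: i for i, w in enumerate(sorted(self._words))}

# integer-encode each input window and its next-word target
X = np.array([[word_indices[w] for w in line] for line in self._codelines])
y = np.array([word_indices[w] for w in self._next_words])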

My keras model is:

from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dropout, Dense, Activation

model = Sequential()
model.add(Embedding(input_dim=len(self._words), output_dim=1024))

model.add(Bidirectional(
    LSTM(128), input_shape=(self.seq_length, len(self._words))))

model.add(Dropout(rate=0.5))
model.add(Dense(len(self._words)))
model.add(Activation('softmax'))

model.compile(loss='sparse_categorical_crossentropy',
              optimizer="adam", metrics=['accuracy'])

But no matter how much I train it, the model never seems to generate <eos> or even \n. I think it might be because my LSTM size is 128 and my seq_length is 200, but that doesn't quite make sense? Is there something I'm missing?

Sometimes, when there is no limit on code generation, or when the <EOS> or <SOS> tokens are not numerical tokens, the LSTM never converges. If you could post your outputs or error messages, it would be much easier to debug.

You could create an extra class for getting words and sentences.

# tokens for start of sentence (SOS) and end of sentence (EOS)
SOS_token = 0
EOS_token = 1


class Lang:
    '''
    class for word object, storing sentences, words and word counts.
    '''
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
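
As a quick illustration of how the class assigns indices (the sample sentence is made up):

lang = Lang('python')
lang.addSentence('import pyux <EOS>')
print(lang.word2index)  # {'import': 2, 'pyux': 3, '<EOS>': 4}
print(lang.n_words)     # 5, counting the built-in SOS and EOS entries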

Then, while generating text, just adding an <SOS> token would do. You can use https://github.com/sherjilozair/char-rnn-tensorflow, a character-level RNN, for reference.
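
A minimal generation loop along those lines, assuming a trained model and the Lang object above; stopping on EOS_token and the greedy argmax choice are simplifications, not part of the original answer:

import numpy as np

def generate(model, lang, seq_length, max_words=200):
    # start from the <SOS> token and predict until <EOS> appears
    sequence = [SOS_token]
    for _ in range(max_words):
        window = sequence[-seq_length:]  # trim to the model's input length
        probs = model.predict(np.array([window]))[0]
        next_index = int(np.argmax(probs))
        if next_index == EOS_token:
            break
        sequence.append(next_index)
    return ' '.join(lang.index2word[i] for i in sequence[1:])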
