簡體   English   中英

為什么我的keras LSTM模型陷入無限循環?

[英]Why does my keras LSTM model get stuck in an infinite loop?

我正在嘗試構建一個小的LSTM,該LSTM可以通過在現有的Python代碼上進行訓練來學習編寫代碼(即使它是垃圾代碼)。 我將幾千行代碼連接在一個文件中,跨越數百個文件,每個文件都以<eos>結尾,以表示“序列結束”。

例如,我的訓練文件如下所示:


setup(name='Keras',
...
      ],
      packages=find_packages())
<eos>
import pyux
...
with open('api.json', 'w') as f:
    json.dump(sign, f)
<eos>

我正在使用以下單詞創建令牌:

file = open(self.textfile, 'r')
filecontents = file.read()
file.close()
filecontents = filecontents.replace("\n\n", "\n")
filecontents = filecontents.replace('\n', ' \n ')
filecontents = filecontents.replace('    ', ' \t ')

text_in_words = [w for w in filecontents.split(' ') if w != '']

self._words = set(text_in_words)
    STEP = 1
    self._codelines = []
    self._next_words = []
    for i in range(0, len(text_in_words) - self.seq_length, STEP):
        self._codelines.append(text_in_words[i: i + self.seq_length])
        self._next_words.append(text_in_words[i + self.seq_length])

我的keras模型是:

model = Sequential()
model.add(Embedding(input_dim=len(self._words), output_dim=1024))

model.add(Bidirectional(
    LSTM(128), input_shape=(self.seq_length, len(self._words))))

model.add(Dropout(rate=0.5))
model.add(Dense(len(self._words)))
model.add(Activation('softmax'))

model.compile(loss='sparse_categorical_crossentropy',
              optimizer="adam", metrics=['accuracy'])

但是,無論我訓練多少,該模型似乎都不會生成<eos>甚至\\n 我認為這可能是因為我的LSTM大小是128seq_length是200,但這不是很有意義嗎? 有什么我想念的嗎?

有時,當limit for code generation沒有limit for code generationthe <EOS> or <SOS> tokens are not numerical tokens LSTM從不收斂。 如果您可以發送輸出或錯誤消息,則調試起來會容易得多。

您可以創建一個額外的類來獲取單詞和句子。

# tokens for start of sentence(SOS) and end of sentence(EOS)

SOS_token = 0
EOS_token = 1


class Lang:
    '''
    class for word object, storing sentences, words and word counts.
    '''
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

然后,在生成文本時,只需添加<SOS>令牌即可。 您可以使用https://github.com/sherjilozair/char-rnn-tensorflow (字符級別rnn)作為參考。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM