繁体   English   中英

文本预测 LSTM 神经网络的问题

[英]Problem with text prediction LSTM neural networks

我正在尝试使用带有书籍数据集的递归神经网络(LSTM)进行文本预测。 不管我尝试改变图层大小或其他参数多少,它总是过拟合。

我一直在尝试更改层数、LSTM 层中的单元数、正则化、归一化、batch_size、shuffle 训练数据/验证数据、将数据集更改为更大。 现在我尝试使用 ~140kb txt 书。 我也试过 200kb、1mb、5mb。

创建训练/验证数据:

sequence_length = 30

x_data = []
y_data = []

for i in range(0, len(text) - sequence_length, 1):
    x_sequence = text[i:i + sequence_length]
    y_label = text[i + sequence_length]

    x_data.append([char2idx[char] for char in x_sequence])
    y_data.append(char2idx[y_label])

X = np.reshape(x_data, (data_length, sequence_length, 1))
X = X/float(vocab_length)
y = np_utils.to_categorical(y_data)

# Split into training and testing set, shuffle data
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2, shuffle=False)

# Shuffle testing set
X_test, y_test = shuffle(X_test, y_test, random_state=0)

创建 model:

model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))

在此处输入图像描述 编译 model:

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

我得到以下特征: 在此处输入图像描述

我不知道如何处理这种过度拟合,因为我正在搜索互联网,尝试了很多东西,但似乎都没有奏效。

我怎样才能得到更好的结果? 这些预测现在似乎并不好。

以下是我接下来要尝试的一些事情。 (本人也是业余爱好者,如有错误请指正)

  1. 尝试从文本中提取向量表示 试用 word2vec、GloVe、FastText、ELMo。 提取向量表示,然后将它们输入网络。 您还可以创建一个嵌入层来帮助解决这个问题。 这个 博客有更多信息。
  2. 256 经常性单位可能太多。 我认为永远不应该从庞大的网络开始。 从小处着手。 看看你是否欠拟合。 如果是,则 go 更大。
  3. 关闭优化器。 我发现亚当倾向于过度拟合。 我在 rmsprop 和 Adadelta 方面取得了更好的成功。
  4. 也许, 注意力就是你所需要的? Transformers 最近为 NLP 做出了巨大贡献。 也许您可以尝试在您的网络中实现简单的软注意机制 如果您还不熟悉,这里有一个不错的视频系列 关于它的交互式研究论文
  5. CNN 在 NLP 应用程序中也非常出色。 尽管它们直观地对文本数据没有任何意义(对大多数人来说)。 也许你可以尝试利用它们,堆叠它,等等。玩。 这是有关如何将其用于句子分类的指南 我知道,你的域是不同的。 但我认为直觉继续存在。 :)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM