繁体   English   中英

Keras LSTM 多类分类结构

[英]Keras LSTM Multiclass Classification structure

我是机器学习的初学者,一直在尝试使用 LSTM 根据 12 个特征将其分类为 4 个类。 我已经遵循了很多教程,但我仍然有点困惑。 我的数据集有 12 列我想用于训练,包括 label 列,它的值对应于每个 class。

0 = Class 1

1 = Class 2

2 = Class 3

3 = Class 4

这是我的代码:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import time
# For LSTM model
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from keras.callbacks import EarlyStopping
from keras import optimizers

# Load dataset
train = pd.read_csv("C:\Users\O\Documents\Datasets\FinalDataset2.csv")

train_proccessed = train.iloc[:, 1:13]

scaler = MinMaxScaler(feature_range = (0, 1))
train_scaled = scaler.fit_transform(train_proccessed)

features_set = []
labels = []
for i in range(1, 393763):
    features_set.append(train_scaled[i-1:i, 0])
    labels.append(train_scaled[i, 0])

features_set, labels = np.array(features_set), np.array(labels)

features_set = np.reshape(features_set, (features_set.shape[0], features_set.shape[1], 1))


# Initialize LSTM model
model = Sequential()

model.add(LSTM(512, return_sequences=True,  activation='tanh', input_shape=(features_set.shape[1], 1)))
model.add(Dropout(0.2))
model.add(Dense(4, activation='softmax'))
model.add(LSTM(units=1, activation='sigmoid'))
opt = optimizers.Adam(lr=0.0001)
model.compile(optimizer = opt , loss = 'categorical_crossentropy', metrics = ['accuracy'])

model.fit(features_set, labels, epochs = 100, batch_size = 512)

我非常不确定我的 model 是否正确构建。 此外,它只产生非常低的准确度(27-28%)。 任何帮助将不胜感激!!

简短的回答:

  1. 最后一层要密集(4,activation='softmax')
  2. 标签必须是一种热编码,因为您使用loss='categorical_crossentropy'

这里有更多说明可以帮助您

第一层

LSTM(512, return_sequences=True,  activation='tanh')
  • 您从庞大的 LSTM 单元开始,而您的数据只有 12 列。
  • return_sequences=True这在你的情况下是不合理的,因为你没有在它之后放置另一层

Model 本体

  • LST 和最终 Dense() 之间的中间没有层
  • 至少添加一个 Dense 层

Output层

  • 将损失用作sparse_categorical_crossentropy而不是categorical_crossentropy可能更容易,因此标签可以作为数字传递,否则您需要对其进行处理

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM