繁体   English   中英

LSTM神经网络输入/输出尺寸错误

[英]LSTM Nerual Network Input/Output dimensions error

我是TensorFlow和LSTM架构的新手。 我在为数据集计算输入和输出(x_train,x_test,y_train,y_test)时遇到问题。

我最初输入的形状:

  • X_train:(366,4)
  • X_test:(104,4)
  • Y_train:(366,)
  • Y_test:(104,)

Ytrain和Ytest是一系列股票价格。 Xtrain和Xtest是我要学习预测股票价格的四个功能。

# Splitting the training and testing data

train_start_date = '2010-01-08'
train_end_date = '2017-01-06'
test_start_date = '2017-01-13'
test_end_date = '2019-01-04'

train = df.ix[train_start_date : train_end_date]
test = df.ix[test_start_date:test_end_date]


X_test = sentimentScorer(test)
X_train = sentimentScorer(train)

Y_test = test['prices'] 
Y_train = train['prices']

#Conversion in 3D array for LSTM INPUT

X_test = X_test.reshape(1, 104, 4)
X_train = X_train.reshape(1, 366, 4)





model = Sequential()

model.add(LSTM(128, input_shape=(366,4), activation='relu', 
return_sequences=True))
model.add(Dropout(0.2))

model.add(LSTM(128, activation='relu'))
model.add(Dropout(0.1))

model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(10, activation='softmax'))

opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)

# Compile model
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy'],
)

model.fit(X_train,
          Y_train,
          epochs=3,
          validation_data=(X_test, Y_test))

这是生成的错误:

-------------------------------------------------- ------------------------- ValueError Traceback(最近一次通话最后一次)以65 Y_train,66 epochs = 3,---> 67validation_data =( X_test,Y_test))

c:\\ users \\ talal \\ appdata \\ local \\ programs \\ python \\ python36 \\ lib \\ site-packages \\ tensorflow \\ python \\ keras \\ engine \\ training.py in fit(self,x,y,batch_size,epochs,verbose,回调,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,** kwargs)1507 steps_name ='steps_per_epoch',1508 steps = steps_per_epoch,-> 1509validation_split = validation_split)1510 1511#准备验证数据。

C:\\ users \\ talal \\ appdata \\ local \\ programs \\ python \\ python36 \\ lib \\ site-packages \\ tensorflow \\ python \\ keras \\ engine \\ training.py in _standardize_user_data(self,x,y,sample_weight,class_weight,batch_size,check_steps ,steps_name,steps,validation_split)991 x,y = next_element 992 x,y,sample_weights = self._standardize_weights(x,y,sample_weight,-> 993 class_weight,batch_size)994 return x,y,sample_weights 995

c:\\ users \\ talal \\ appdata \\ local \\ programs \\ python \\ python36 \\ lib \\ site-packages \\ tensorflow \\ python \\ keras \\ engine \\ training.py in _standardize_weights(self,x,y,sample_weight,class_weight,batch_size)1110 feed_input_shapes,1111
check_batch_axis = False,#不强制执行批量大小。 -> 1112 exception_prefix ='input')1113 1114如果y不为None:

c:\\ users \\ talal \\ appdata \\ local \\ programs \\ python \\ python36 \\ lib \\ site-packages \\ tensorflow \\ python \\ keras \\ engine \\ training_utils.py in standardize_input_data(数据,名称,形状,check_batch_axis,exception_prefix)314':预期的'+名称[i] +'具有'+ 315 str(len(shape))+'尺寸,但得到数组'-> 316'具有形状'+ str(data_shape))如果不是check_batch_axis,则为317:318 data_shape = data_shape [1:]

ValueError:检查输入时出错:预期lstm_18_input具有3个维,但数组的形状为(366,4)

X_train的尺寸错误。LSTM仅接受3维输入。您说的是具有4个功能。 假设366是您对一个样本的时间戳数,则您的输入应为(num_samples,366,4)个形状欢呼声:-)

您的代码几乎可以了。

您的y_testy_train应该是一个具有一个元素或形状为(1,1)的数组,没关系。

但是,您的输入形状有误,第一个LSTM应该是:

model.add(LSTM(128, input_shape=(None,4), activation='relu', return_sequences=True))

注意None ,因为测试序列和训练序列的长度不同,所以无法指定它(并且Keras接受未指定的第一维)。 错误是由于长度分别为366和104。 如果要对keras.preprocessing.sequence.pad_sequences使用批处理,则应使用keras.preprocessing.sequence.pad_sequences进行零填充。

无需批量指定input_shape ,网络的其余部分应该没问题。

而且,如果要执行回归而不是分类(可能是这种情况),则应执行@Ankish Bansal编写的最后两个步骤,例如,将损耗更改为mean squared error ,并使最后一层的输出值为1而不是10。

  1. LSTM期望将调光输入为(num_examples,seq_length,input_dims),因此输入中存在一个错误。

  2. 您正在预测维度1的输出,您的模型输出是10。

    model.add(Dense(1, activation='linear'))

  3. 另外,您正在预测价格,这就是回归问题。 但是,您正在使用分类设置。 尝试这个

    model.compile(loss='mse', optimizer='adam', metrics=['mean_squared_error'])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM