[英]Keras fit_generator() for long signals
我想制作一个 LSTM 网络,我有一个很长的信号,我想用作我的训练数据
如果我尝试将 CSV 文件作为数据框或 Python 中的数组上传,则会超出 RAM 的限制,因此我的想法是使用 fit_generator()。 我之前在制作图像模型时使用过这个,但后来我只使用了一些预构建生成器,但我真的找不到任何用于信号的预构建生成器,所以我决定尝试自己制作一个。
def generate_data_to_model(y_train):
while True:
with open("/mypath/myData.csv") as f:
for line in f:
x= line.rstrip('\n').split(",")
x= np.asarray(x)
x=x[1:]
x= x.reshape(1,1,12)
yield (x, y_train[line])
model = Sequential()
model.add(LSTM(32, input_shape=(1, 12)))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(9, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer="Adam", metrics=['acc'])
model.fit_generator(generate_data_to_model(y_train),
steps_per_epoch=1, epochs=2, verbose=1)
当我开始训练时,我收到此错误:
<ipython-input-29-33b2195fea44> in generate_data_to_model(y_train)
8 x= x.reshape(1,1,12)
---> 9 yield (x, y_train[line])
只有整数、切片 (
:
)、省略号 (...
)、 numpy.newaxis (None
) 和整数或布尔数组是有效索引
我从这篇文章中找到了一些很好的帮助来解决这个问题。 我发现真正方便的是制作两个发电机。 一个用于特征,或 X,一个用于标签,或 y。
def generate_X():
while True:
with open("/mypath/myData.csv") as f:
for line in f:
x= line.rstrip('\n').split(",")
x= np.asarray(x)
x=x[1:]
x= x.reshape(1,1,12)
yield x
def generate_y():
while True:
for i in range(len(y_train)):
y= y_train[i]
yield y
然后我继续使用这两个生成器作为我的第三个生成器的输入,它将为 fit_generator() 生成所需大小的批次
def batch_generator(batch_size, gen_x,gen_y):
batch_features = np.zeros((batch_size,1, 12))
batch_labels = np.zeros((batch_size,9))
while True:
for i in range(batch_size):
batch_features[i] = next(gen_x)
batch_labels[i] = next(gen_y)
yield batch_features, batch_labels
最后,但同样重要的是,我现在可以在训练模型时使用 batch_generator
model.fit_generator(batch_generator(128, generate_X(), generate_y()),
steps_per_epoch=(len(y_train)/128), epochs=2, verbose=1)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.