繁体   English   中英

Keras fit_generator() 用于长信号

[英]Keras fit_generator() for long signals

我想制作一个 LSTM 网络,我有一个很长的信号,我想用作我的训练数据

  • 我的X_train是一个 CSV 文件,其中包含 12 个信号,长度为54 837 488
  • 我的y_train是一个数组,其中包含长度为54 837 488的 One Hot 编码信号( 9 个类别

如果我尝试将 CSV 文件作为数据框或 Python 中的数组上传,则会超出 RAM 的限制,因此我的想法是使用 fit_generator()。 我之前在制作图像模型时使用过这个,但后来我只使用了一些预构建生成器,但我真的找不到任何用于信号的预构建生成器,所以我决定尝试自己制作一个。

def generate_data_to_model(y_train):
    while True:
        with open("/mypath/myData.csv") as f:
            for line in f:
                x= line.rstrip('\n').split(",")
                x= np.asarray(x)
                x=x[1:]
                x= x.reshape(1,1,12)
                yield (x, y_train[line])


model = Sequential()
model.add(LSTM(32, input_shape=(1, 12)))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(9, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer="Adam", metrics=['acc'])

model.fit_generator(generate_data_to_model(y_train),
                    steps_per_epoch=1, epochs=2, verbose=1)

当我开始训练时,我收到此错误:

<ipython-input-29-33b2195fea44> in generate_data_to_model(y_train)
      8                 x= x.reshape(1,1,12)
 ---> 9                 yield (x, y_train[line])

只有整数、切片 ( : )、省略号 ( ... )、 numpy.newaxis ( None ) 和整数或布尔数组是有效索引

我从这篇文章中找到了一些很好的帮助来解决这个问题。 我发现真正方便的是制作两个发电机。 一个用于特征,或 X,一个用于标签,或 y。

def generate_X():
    while True:
        with open("/mypath/myData.csv") as f:
            for line in f:
                x= line.rstrip('\n').split(",")
                x= np.asarray(x)
                x=x[1:]
                x= x.reshape(1,1,12)
                yield x

def generate_y():
    while True:
        for i in range(len(y_train)):
            y= y_train[i]
            yield y

然后我继续使用这两个生成器作为我的第三个生成器的输入,它将为 fit_generator() 生成所需大小的批次

def batch_generator(batch_size, gen_x,gen_y): 

    batch_features = np.zeros((batch_size,1, 12))
    batch_labels = np.zeros((batch_size,9))

    while True:
        for i in range(batch_size):
            batch_features[i] = next(gen_x)
            batch_labels[i] = next(gen_y)
        yield batch_features, batch_labels

最后,但同样重要的是,我现在可以在训练模型时使用 batch_generator

model.fit_generator(batch_generator(128, generate_X(), generate_y()),
                    steps_per_epoch=(len(y_train)/128), epochs=2, verbose=1)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM