繁体   English   中英

如何为 RNN 输入多个时间序列数据(wave)

[英]How to input multiple time-series data(wave) for RNN

我制作了简单的 RNN 模型来学习和拟合一个波形文件。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM,Dropout,Dense
from tensorflow.keras.layers import SimpleRNN
import librosa
import librosa.display
import numpy as np


a, sr = librosa.load("A.wav",sr=22050)
rawdata = librosa.stft(a, n_fft=512,hop_length= 512 // 4, window='hann') #make fourier transfered A.wav data.
rawdata = rawdata.transpose() # [Frame,Freq] => [Freq,Frame]

input_len = 10 # the frame number for learning to make next one
input=[]
target=[]
for i in range(0, len(rawdata) - input_len): 
    input.append( rawdata[i:i+input_len] ) # frames
    target.append( rawdata[i+input_len] )  # one step forward frame for answer.
 
X = np.array(input)
Y = np.array(target)

#Separate 8:2 for training and test
x, val_x, y, val_y = train_test_split(X, Y, test_size=int(X.shape[0] * 0.2), shuffle=False)

n_hidden = 512
epoch = 100
model = Sequential()

model.add(SimpleRNN(n_hidden, input_shape=(input_len, n_in), return_sequences=False))
model.add(Dense(n_hidden, activation="linear")) 
model.add(Dense(n_in, activation="linear"))

opt = Adam(lr=0.001)
model.compile(loss='mse', optimizer=opt)

model.summary()

history = model.fit(x, y, epochs=epoch, batch_size=10,validation_data=(val_x, val_y))

好的,它工作正常。 它学习一个波形文件A.wav

但是我怎样才能学习多个波形文件?

B.wav C.wav

例如,

如果我为每个 wav 多次使用 model.fit(),这个模型是否记得过去的学习?

是的,模型在拟合期间确实记得之前的训练,您也可以多次使用拟合。 但最好使用model.train_on_batch这是适用于小批量数据的简单版本。

您还可以修改代码以将其他 wav 文件功能添加到数据中。

# second way
input_len = 10 # the frame number for learning to make next one
input=[]
target=[]

for f in ['A.wav','B.wav','C.wav']:
  a, sr = librosa.load(f,sr=22050)
  rawdata = librosa.stft(a, n_fft=512,hop_length= 512 // 4, window='hann') #make fourier transfered A.wav data.
  rawdata = rawdata.transpose() # [Frame,Freq] => [Freq,Frame]

  for i in range(0, len(rawdata) - input_len): 
    input.append( rawdata[i:i+input_len] ) # frames
    target.append( rawdata[i+input_len] )  # one step forward frame for answer.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM