
How to feed multiple time-series (wave) files to an RNN

I made a simple RNN model that learns to fit a single wave file.

from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense
from tensorflow.keras.layers import SimpleRNN
from tensorflow.keras.optimizers import Adam
import librosa
import librosa.display
import numpy as np


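a, sr = librosa.load("A.wav", sr=22050)
# Short-time Fourier transform of A.wav; the result is complex, shaped [Freq, Frame]
rawdata = librosa.stft(a, n_fft=512, hop_length=512 // 4, window='hann')
rawdata = np.abs(rawdata)  # keep only the magnitude so the model gets real-valued input (phase is discarded)
rawdata = rawdata.transpose()  # [Freq, Frame] => [Frame, Freq]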

input_len = 10  # number of past frames used to predict the next frame
inputs = []     # renamed from `input` to avoid shadowing the Python builtin
targets = []
for i in range(0, len(rawdata) - input_len):
    inputs.append(rawdata[i:i + input_len])  # window of input_len consecutive frames
    targets.append(rawdata[i + input_len])   # the frame right after the window

X = np.array(inputs)
Y = np.array(targets)

# Split 8:2 into training and validation sets, keeping time order (no shuffling)
x, val_x, y, val_y = train_test_split(X, Y, test_size=int(X.shape[0] * 0.2), shuffle=False)

n_in = X.shape[2]  # number of frequency bins per frame
n_hidden = 512
epoch = 100
model = Sequential()

model.add(SimpleRNN(n_hidden, input_shape=(input_len, n_in), return_sequences=False))
model.add(Dense(n_hidden, activation="linear")) 
model.add(Dense(n_in, activation="linear"))

opt = Adam(learning_rate=0.001)  # `lr` is the deprecated name of this argument
model.compile(loss='mse', optimizer=opt)

model.summary()

history = model.fit(x, y, epochs=epoch, batch_size=10,validation_data=(val_x, val_y))

This works fine: the model learns the single wave file A.wav.

However, how can I train it on multiple wave files, such as B.wav and C.wav?

For example, if I call model.fit() once for each wav file, does the model remember what it learned from the earlier files?

Yes, the model keeps its weights between calls to fit, so you can call fit multiple times and each call continues from the previous training. When you feed small batches of data yourself, it is better to use model.train_on_batch, a simpler version of fit that runs a single update on one batch.
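
A minimal sketch of both options, to make the point about persistent weights concrete. It uses small random arrays in place of real spectrograms so that it runs without any wav files; the helper fake_windows, the layer sizes, and the three pretend datasets are assumptions for illustration, not part of the original code.

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, SimpleRNN, Dense

input_len, n_in = 10, 4  # small made-up sizes so the sketch runs quickly


def fake_windows(seed, n_frames=60):
    """Hypothetical stand-in for one file's (X, Y) windows; random data, not audio."""
    rng = np.random.default_rng(seed)
    frames = rng.random((n_frames, n_in)).astype("float32")
    X = np.stack([frames[i:i + input_len] for i in range(n_frames - input_len)])
    Y = frames[input_len:]  # the frame that follows each window
    return X, Y


model = Sequential([
    Input(shape=(input_len, n_in)),
    SimpleRNN(8),
    Dense(n_in, activation="linear"),
])
model.compile(loss="mse", optimizer="adam")

datasets = [fake_windows(seed) for seed in range(3)]  # pretend A, B and C

weights_before = [w.copy() for w in model.get_weights()]
for X, Y in datasets:
    # Each call starts from the current weights; nothing is reset in between.
    model.fit(X, Y, epochs=1, batch_size=10, verbose=0)
weights_after_fit = [w.copy() for w in model.get_weights()]

# train_on_batch runs one gradient update on exactly the batch it is given.
X, Y = datasets[0]
loss = model.train_on_batch(X[:10], Y[:10])
```

One caveat with training file by file: the weights may drift toward whichever file was seen last, so looping over the files for several rounds, or combining them into one dataset, may give more even results.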

Alternatively, you can modify your code so that the features from the other wav files are added to the same training data.

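# Second way: build one training set from all files
input_len = 10  # number of past frames used to predict the next frame
inputs = []
targets = []

for f in ['A.wav', 'B.wav', 'C.wav']:
    a, sr = librosa.load(f, sr=22050)
    # Magnitude of the short-time Fourier transform of this file, shaped [Freq, Frame]
    rawdata = np.abs(librosa.stft(a, n_fft=512, hop_length=512 // 4, window='hann'))
    rawdata = rawdata.transpose()  # [Freq, Frame] => [Frame, Freq]

    # Windows are built per file, so no window spans two different files
    for i in range(0, len(rawdata) - input_len):
        inputs.append(rawdata[i:i + input_len])  # window of input_len consecutive frames
        targets.append(rawdata[i + input_len])   # the frame right after the window

X = np.array(inputs)
Y = np.array(targets)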
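
The per-file windowing above can be checked without any audio. The sketch below is only an illustration: make_windows is a hypothetical helper, and the random arrays stand in for [Frame, Freq] spectrograms of three files of different lengths (257 bins is what n_fft=512 yields, since n_fft // 2 + 1 = 257).

```python
import numpy as np


def make_windows(frames, input_len):
    """Slice one file's [Frame, Freq] array into (inputs, targets) pairs."""
    n = len(frames) - input_len
    inputs = [frames[i:i + input_len] for i in range(n)]
    targets = [frames[i + input_len] for i in range(n)]
    return np.array(inputs), np.array(targets)


input_len = 10
n_freq = 257  # n_fft // 2 + 1 frequency bins for n_fft=512

# Random arrays standing in for the spectrograms of three files.
rng = np.random.default_rng(0)
files = [rng.random((n_frames, n_freq)) for n_frames in (40, 55, 30)]

# Window each file separately, then stack the windows of all files.
pairs = [make_windows(frames, input_len) for frames in files]
X = np.concatenate([p[0] for p in pairs])
Y = np.concatenate([p[1] for p in pairs])

print(X.shape, Y.shape)
```

Each file contributes len(frames) - input_len windows, so the stacked arrays hold 30 + 45 + 20 = 95 samples. Note that with shuffle=False in train_test_split, the validation set is taken from the end of this stacked array, which means it comes mostly from the last file.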
