
How to convert a MATLAB deep learning model to Keras code

I am building a binary sound classification model with Keras on Python 3.7. I had already built the sound classification model in MATLAB, but some layers I need (e.g. GRU) are not available in MATLAB, so I am trying to convert the MATLAB deep learning model into a Keras model.

The original MATLAB code is shown below:

inputsize=[31,69]
layers = [ ...
    sequenceInputLayer(inputsize(1))
    bilstmLayer(200,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ]
options = trainingOptions('adam', ...
    'MaxEpochs',30, ...
    'MiniBatchSize', 200, ...
    'InitialLearnRate', 0.01, ...
    'GradientThreshold', 1, ...
    'ExecutionEnvironment',"auto",...
    'plots','training-progress', ...
    'Verbose',false);

This model reaches an accuracy of 0.955.
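The 'GradientThreshold', 1 option is one MATLAB setting that the Keras version does not reproduce. With MATLAB's default 'l2norm' threshold method, a learnable parameter whose gradient has an L2 norm above the threshold gets that gradient rescaled so its norm equals the threshold. A minimal NumPy sketch of that rule for a single parameter (clip_by_l2_norm is a hypothetical helper, not a MATLAB or Keras function):

```python
import numpy as np

def clip_by_l2_norm(grad, threshold=1.0):
    """Hypothetical helper: rescale one parameter's gradient so its L2 norm
    is at most `threshold` (MATLAB's default 'l2norm' GradientThreshold rule)."""
    norm = np.linalg.norm(grad)  # 2-norm of the flattened gradient
    if norm > threshold:
        return grad * (threshold / norm)
    return grad  # already within the threshold: left unchanged

g = np.array([3.0, 4.0])        # L2 norm is 5
print(clip_by_l2_norm(g))       # rescaled to norm 1, i.e. about [0.6, 0.8]
```

In Keras 2.2.4 the closest built-in analogue is the clipnorm argument accepted by every optimizer, e.g. Adam(lr=0.01, clipnorm=1.0), which would also match MATLAB's InitialLearnRate of 0.01. Note that standalone Keras computes this norm over all gradients together rather than per parameter, so the behavior is close but may not be identical.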

The Keras code based on the MATLAB code is shown below:

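from keras.models import Sequential
from keras.layers import Bidirectional, LSTM, Dense
from keras.optimizers import RMSprop

# traindata has shape (86400, 31, 69); trainlabel holds the matching class labels
inputsize = (31, 69)
batchsize = 200
epochs = 30

model = Sequential()
# input_shape belongs on the Bidirectional wrapper, since it is the first layer of the model
model.add(Bidirectional(LSTM(200), input_shape=inputsize))
model.add(Dense(2, activation='softmax'))

model.compile(optimizer=RMSprop(), loss='binary_crossentropy', metrics=['accuracy'])

model.fit(traindata, trainlabel, batch_size=batchsize, epochs=epochs, verbose=1)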

This model only reaches an accuracy of 0.444.

I don't understand what causes this difference. Both models were trained on the same STFT data, which was normalized beforehand by subtracting the mean and dividing by the standard deviation. Any comments would be appreciated.
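Since both models are supposed to see the same data, two things are worth checking. First, axis order: in MATLAB, each sequence given to sequenceInputLayer(inputsize(1)) is a matrix with one row per feature and one column per time step, so the MATLAB model sees 31 features at each of 69 time steps. Keras recurrent layers expect (samples, timesteps, features), so input_shape=(31,69) tells Keras there are 31 time steps of 69 features. If traindata was exported from MATLAB without reordering its axes (the question doesn't say), the two models read the spectrogram along different axes. Second, normalization: the question doesn't say whether the mean and standard deviation were computed per feature or over the whole array, and both pipelines should do the same thing. A minimal NumPy sketch, assuming MATLAB-style storage and per-feature standardization; the helper names and random arrays are hypothetical placeholders, not the real STFT data:

```python
import numpy as np

def matlab_to_keras_layout(x):
    """Reorder (samples, features, timesteps) -> (samples, timesteps, features)."""
    return np.transpose(x, (0, 2, 1))

def standardize(train, test, eps=1e-8):
    """Z-score each feature (last axis) with statistics from the training set only."""
    mean = train.mean(axis=(0, 1), keepdims=True)  # per-feature mean over samples and time steps
    std = train.std(axis=(0, 1), keepdims=True)    # per-feature standard deviation
    return (train - mean) / (std + eps), (test - mean) / (std + eps)

# Placeholder data stored MATLAB-style: 31 features x 69 time steps per sequence.
rng = np.random.default_rng(0)
train_raw = rng.normal(5.0, 2.0, size=(100, 31, 69))
test_raw = rng.normal(5.0, 2.0, size=(20, 31, 69))

train_k = matlab_to_keras_layout(train_raw)   # (100, 69, 31)
test_k = matlab_to_keras_layout(test_raw)     # (20, 69, 31)
train_n, test_n = standardize(train_k, test_k)

# With this layout, the Keras layer would be declared with input_shape=(69, 31).
print(train_n.shape, test_n.shape)
```

If the data is already stored as (samples, timesteps, features), only the standardization step applies.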

Python 3.7 on Anaconda

Keras 2.2.4

I think that's because the MATLAB code trains with the Adam optimizer, whereas you specified RMSprop here:

model.compile(optimizer=RMSprop(), loss='binary_crossentropy', metrics=['accuracy'])

Use Adam instead:

from keras import optimizers
# Keras 2.2.4 names the learning-rate argument `lr`; `learning_rate` was introduced in later releases
adam = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)

...

model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['accuracy'])

Check whether this improves the accuracy.
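Besides the optimizer, the loss may not match either. MATLAB's fullyConnectedLayer(2) followed by softmaxLayer and classificationLayer computes cross-entropy over two softmax outputs; in Keras the direct counterpart of Dense(2, activation='softmax') is loss='categorical_crossentropy' with one-hot labels. If trainlabel holds plain 0/1 integers, it needs converting first. A minimal NumPy sketch with hypothetical labels:

```python
import numpy as np

# Hypothetical integer labels (0 or 1), one per training sequence.
labels = np.array([0, 1, 1, 0])

# One-hot encoding: class i becomes row i of the 2x2 identity matrix.
# keras.utils.to_categorical(labels, num_classes=2) returns the same values.
onehot = np.eye(2)[labels]
print(onehot)  # rows: [1, 0], [0, 1], [0, 1], [1, 0]
```

With labels in this form, the model can be compiled with loss='categorical_crossentropy', which mirrors MATLAB's classificationLayer more directly than binary_crossentropy does.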

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.