I'm working on a project to detect drum onsets in audio. I have preprocessed my training data and tried to put together a SimpleRNN network in TensorFlow, but I couldn't get the two working together.
At each timestep, my input is a 1D tensor of shape (84,), and the output should be a tensor of shape (3,).
Currently my code looks like this:
from tensorflow import keras
from tensorflow.keras import layers
train_epochs = 10
batch_num = 10
learning_Rate = 0.001
''' I also tried using tf.dataset but couldn't get it to work
train_dataset = dataset.batch(batch_num, drop_remainder=True)
test_dataset = dataset.take(10000).batch(batch_num, drop_remainder=True)
print(train_dataset.element_spec)
'''
x_data = x_data[:70000]
y_data = y_data[:70000]
x_data.resize((70000, 84))
y_data.resize((70000, 3))
print(x_data.shape, y_data.shape)
model = keras.Sequential()
model.add(keras.Input(shape=(None,84)))
model.add(layers.SimpleRNN(200,activation='relu', dropout=0.2))
model.add(layers.Dense(3, activation='sigmoid'))
model.compile(
optimizer=keras.optimizers.RMSprop(learning_rate=learning_Rate),
loss=keras.losses.BinaryCrossentropy(),
#metrics F measure
metrics=['acc',f1_m,precision_m, recall_m]
)
model.summary()
history = model.fit(
x_data,y_data,
epochs=train_epochs,
batch_size=batch_num,
# We pass some validation for
# monitoring validation loss and metrics
# at the end of each epoch
validation_data=(x_data, y_data)
)
print("Evaluate on test data")
results = model.evaluate(test_dataset)
print("test loss, test acc:", results)
When I run it, I get this error message:
ValueError: Input 0 of layer sequential_35 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (10, 84)
If I reshape x_data and y_data to (7000, 10, 84) and (7000, 10, 3), the error message becomes:
ValueError: logits and labels must have the same shape ((10, 3) vs (10, 10, 3))
How can I fix this issue? I'm very new to deep learning, so any advice on how to work on the project is much appreciated.
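The input to SimpleRNN must be 3D, with shape (batch, timesteps, features). For example, to treat each 84-value frame as a single timestep, so that the data matches keras.Input(shape=(None, 84)):
x_data.resize((70000, 1, 84))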
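For the second error in the question (outputs of shape (10, 3) against labels of shape (10, 10, 3)), one possible fix is to have the RNN emit an output at every timestep with return_sequences=True, so that the Dense layer produces one 3-value prediction per frame. The following is only a minimal sketch: the random arrays are hypothetical stand-ins for the real features and labels, the sequence length of 10 is an assumption taken from the (7000, 10, 84) attempt, and the custom metrics from the question are left out.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

timesteps = 10  # assumed sequence length, as in the (7000, 10, 84) attempt

# Hypothetical stand-in data: 200 frames, 84 features and 3 binary labels each.
rng = np.random.default_rng(0)
x_frames = rng.random((200, 84)).astype("float32")
y_frames = rng.integers(0, 2, size=(200, 3)).astype("float32")

# Group consecutive frames into sequences: (batch, timesteps, features).
x_seq = x_frames.reshape(-1, timesteps, 84)  # (20, 10, 84)
y_seq = y_frames.reshape(-1, timesteps, 3)   # (20, 10, 3)

model = keras.Sequential([
    keras.Input(shape=(None, 84)),
    # return_sequences=True returns the output of every timestep,
    # not only the last one, so predictions line up with per-frame labels.
    layers.SimpleRNN(200, activation="relu", dropout=0.2, return_sequences=True),
    layers.Dense(3, activation="sigmoid"),
])
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.001),
    loss=keras.losses.BinaryCrossentropy(),
)
model.fit(x_seq, y_seq, epochs=1, batch_size=10, verbose=0)

pred = model.predict(x_seq, verbose=0)
print(pred.shape)
```

With this layout the model output and the labels are both (batch, 10, 3), so the loss can compare them frame by frame. If only one prediction per sequence is wanted instead, the labels would need to be reduced to shape (batch, 3) and return_sequences left at its default.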