
TensorFlow training crashes in the last step of the first epoch for an audio classifier

I am trying to port a custom DNN originally written in PyTorch to TensorFlow 2. The network expects inputs of shape (batch_size, 39, 101, 1). After running the feature extractor, my training and validation sets have the compatible shape (total_samples, 39, 101, 1), and I train the model with model.fit. In the very last step of the first epoch, however, I get the error below, which suggests that the tensor shape is somehow changing. I don't understand why this happens only in the last step. My callbacks are simple ones for model checkpointing and early stopping, as shown below. My data comes from the Google Speech dataset v0.01. I believe the error occurs before the validation step starts. Any suggestions on how to fix it would be much appreciated.

Here are the dimensions of my training and validation datasets:

print(x_tr.shape) -> (17049, 39, 101, 1)
print(y_tr.shape) -> (17049, 10)
print(x_val.shape) -> (4263, 39, 101, 1)
print(y_val.shape) -> (4263, 10)
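from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

modelname, input_shape, numclass = 'CRNN', (39, 101, 1), 10

# modelcreator is the project's own model-building module (not part of TensorFlow)
model = modelcreator.getmodel(modelname, input_shape, numclass)
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10, min_delta=0.0001)
# TF 2.x logs this metric as 'accuracy'/'val_accuracy' (see the training log), so 'val_acc' would never be found
mc = ModelCheckpoint('best_model.hdf5', monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
history = model.fit(x_tr, y_tr, epochs=100, callbacks=[es, mc], batch_size=64, validation_data=(x_val, y_val))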

Epoch 1/100 266/267 [============================>.] - ETA: 0s - loss: 0.9436 - accuracy: 0.6963
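The 266/267 step counter and the 25 that shows up in the error message both follow from the sample counts above. For NumPy inputs, model.fit keeps the final partial batch, so with batch_size=64 the number of steps is the ceiling of samples / 64 and the last batch holds only the remainder. A small sketch of that arithmetic (batch_plan is a hypothetical helper, not a Keras API):

```python
def batch_plan(num_samples, batch_size):
    """Return (steps_per_epoch, last_batch_size) when the final partial batch is kept."""
    full_batches, remainder = divmod(num_samples, batch_size)
    steps = full_batches + (1 if remainder else 0)
    # If the data divides evenly, the last batch is simply a full batch.
    last_batch = remainder if remainder else batch_size
    return steps, last_batch


# Sample counts from x_tr.shape[0] and x_val.shape[0] above; batch_size from model.fit.
print(batch_plan(17049, 64))  # (267, 25)
print(batch_plan(4263, 64))   # (67, 39)
```

The same arithmetic gives a last validation batch of 39 samples, so validation would presumably hit the same error if the model only accepts batches of exactly 64.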

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
----> 1 history=model.fit(x_tr, y_tr,epochs=100, callbacks=[es,mc], batch_size=64, validation_data=(x_val,y_val))

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106 def _method_wrapper(self, *args, **kwargs):
    107 if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
    109
    110 # Running inside run_distribute_coordinator already.

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096 batch_size=batch_size):
   1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
   1099 if data_handler.should_sync:
   1100 context.async_wait()

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778 else:
    779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
    781
    782 new_tracing_count = self._get_tracing_count()

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    805 # In this case we have created variables on the first call, so we run the
    806 # defunned version which is guaranteed to never create variables.
--> 807 return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    808 elif self._stateful_fn is not None:
    809 # Release the lock early so that multiple threads can perform the call

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2827 with self._lock:
   2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2830
   2831 @property

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1846 resource_variable_ops.BaseResourceVariable))],
   1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
   1849
   1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1922 # No tape is watching; skip to running the function.
   1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
   1925 forward_backward = self._select_forward_and_backward_functions(
   1926 args,

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    548 inputs=args,
    549 attrs=attrs,
--> 550 ctx=ctx)
    551 else:
    552 outputs = execute.execute_with_cancellation(

~/Desktop/Spoken_Keyword_Spotting/newenv/lib/python3.6/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58 ctx.ensure_initialized()
     59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
     61 except core._NotOkStatusException as e:
     62 if name is not None:

InvalidArgumentError: Specified a list with shape [64,512] from a tensor with shape [25,512] [[{{node TensorArrayUnstack/TensorListFromTensor}}]]
[[functional_3/lstm_1/PartitionedCall]] [Op:__inference_train_function_13255]

Function call stack: train_function -> train_function -> train_function

I figured it out from the answer to this question: InvalidArgumentError: Specified a list with shape [60,9] from a tensor with shape [56,9]

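My model ends with an LSTM layer with return_sequences=True, and that layer was apparently built for a fixed batch of 64 (the [64,512] in the error). The training set does not divide evenly into batches of 64: 17049 = 266 × 64 + 25, so the last (267th) step of the epoch receives only 25 samples (the [25,512] in the error). That smaller final batch is what triggered this error.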
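One way to act on this, assuming the model really cannot accept a variable batch size, is to drop the final partial batch so every batch has exactly 64 samples. For in-memory NumPy arrays that just means trimming to a multiple of the batch size; tf.data users can get the same effect with Dataset.batch(64, drop_remainder=True). A minimal sketch, where drop_remainder is a hypothetical helper name:

```python
import numpy as np


def drop_remainder(x, y, batch_size):
    """Trim x and y to the largest multiple of batch_size so every batch is full.

    Hypothetical helper: same idea as tf.data's batch(..., drop_remainder=True),
    applied to in-memory NumPy arrays passed straight to model.fit.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of samples")
    usable = (len(x) // batch_size) * batch_size
    return x[:usable], y[:usable]


# Stand-in arrays with the question's sample counts; the feature axis is
# shrunk from (39, 101, 1) to (4,) only to keep the example light.
x_tr = np.zeros((17049, 4), dtype=np.float32)
y_tr = np.zeros((17049, 10), dtype=np.float32)

x_tr_trim, y_tr_trim = drop_remainder(x_tr, y_tr, batch_size=64)
print(x_tr_trim.shape, y_tr_trim.shape)  # (17024, 4) (17024, 10)
```

The trimmed arrays would then go to model.fit, with x_val/y_val trimmed the same way for validation_data. This discards 25 training and 39 validation samples per epoch; if the model code allows it, leaving the batch dimension unspecified (None) so the LSTM accepts any batch size avoids losing data.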
