I'm creating an LSTM model for text generation using Keras. The dataset I'm using (around 25 novels, roughly 1.4 million words) can't be processed at once (I run into a memory issue when converting my outputs with to_categorical()), so I created a custom generator function to read the data in.
# Data generator for fit and evaluate
def generator(batch_size):
    start = 0
    end = batch_size
    while True:
        x = sequences[start:end, :-1]
        #print(x)
        y = sequences[start:end, -1]
        y = to_categorical(y, num_classes=vocab_size)
        #print(y)
        yield x, y
        if batch_size == len(lines):
            break
        else:
            start += batch_size
            end += batch_size
When I execute the model.fit() method, the following error is thrown after one epoch of training is done.
UnknownError: [_Derived_] CUDNN_STATUS_BAD_PARAM
in tensorflow/stream_executor/cuda/cuda_dnn.cc(1459): 'cudnnSetTensorNdDescriptor( tensor_desc.get(), data_type, sizeof(dims) / sizeof(dims[0]), dims, strides)'
[[{{node CudnnRNN}}]]
[[sequential/lstm/StatefulPartitionedCall]] [Op:__inference_train_function_25138]
Function call stack:
train_function -> train_function -> train_function
Does anyone know how to solve this issue? Thanks.
According to many sources on the Internet, this issue seems to occur when an LSTM layer is used together with a Masking layer while training on a GPU.
The following workarounds may help:
1. If you can compromise on speed, you can train your model on the CPU rather than on the GPU. It works without any error.
2. As per this comment, please check whether any of your input sequences consist entirely of zeros, as the Masking layer may then mask all the inputs.
3. If possible, you can disable eager execution. As per this comment, it works without any error.
4. Instead of using a Masking layer, you can try the alternatives mentioned in this link:
   a. Add the argument mask_zero=True to the Embedding layer, or
   b. Pass a mask argument manually when calling layers that support this argument.
5. The last option is to remove the Masking layer, if that is possible.
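The check for input sequences made up entirely of zeros can be sketched with NumPy as below. This is only an illustration under stated assumptions: the helper name `find_all_zero_rows`, the `pad_value` parameter, and the toy batch are hypothetical and not part of the original code, and it assumes 0 is the padding value that gets masked.

```python
import numpy as np

def find_all_zero_rows(x, pad_value=0):
    """Return indices of rows in a 2-D batch that contain only the padding value."""
    x = np.asarray(x)
    return np.flatnonzero((x == pad_value).all(axis=1))

# Hypothetical batch of token ids: row 1 is entirely padding.
batch = np.array([[4, 7, 0, 0],
                  [0, 0, 0, 0],
                  [9, 0, 0, 0]])

bad_rows = find_all_zero_rows(batch)
# Drop the fully padded rows before they reach the masked LSTM.
clean_batch = np.delete(batch, bad_rows, axis=0)
```

If `bad_rows` is non-empty for your data, the matching targets would need to be dropped as well so that inputs and outputs stay aligned.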
If none of the above workarounds solves your problem, note that the TensorFlow team at Google is working to resolve this error, so we may have to wait until it is fixed.
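Separately, because the error appears only after the first epoch, it may be worth checking whether the generator in the question keeps advancing `start` and `end` past the end of the data and starts yielding empty batches. This is an assumption, not a confirmed diagnosis. The sketch below shows one way to make such a generator wrap around. The name `wrapping_generator`, its parameters, and the toy data are hypothetical, and the one-hot step is a plain NumPy stand-in for `to_categorical` so that the example runs without TensorFlow.

```python
import numpy as np

def wrapping_generator(sequences, batch_size, vocab_size):
    """Yield (x, y) batches forever, restarting once the data runs out."""
    n = len(sequences)
    start = 0
    while True:
        end = min(start + batch_size, n)
        x = sequences[start:end, :-1]
        labels = sequences[start:end, -1]
        # One-hot encode only this batch (NumPy stand-in for to_categorical).
        y = np.zeros((len(labels), vocab_size), dtype=np.float32)
        y[np.arange(len(labels)), labels] = 1.0
        yield x, y
        # Restart from the beginning instead of sliding past the end of the data.
        start = end if end < n else 0

# Toy data: 5 sequences of 3 input tokens plus 1 target token, vocabulary of 10.
toy = np.arange(20).reshape(5, 4) % 10
gen = wrapping_generator(toy, batch_size=2, vocab_size=10)
batch_sizes = [len(next(gen)[0]) for _ in range(6)]
```

With a generator like this, `steps_per_epoch` in `model.fit()` would typically be set to `math.ceil(len(sequences) / batch_size)` so that one epoch corresponds to one pass over the data.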
Hope this information helps. Happy Learning!