Keras can't save model with CuDNNLSTM as SavedModel
I recently ran into a problem with Keras. My model looks like this:
inputs = Input(shape=(max_sequence_len,))
# Embedding layer
embedding = Embedding(
    input_length=max_sequence_len,
    input_dim=len(word_idx),
    output_dim=100,
    weights=[embedding_matrix],
    trainable=False
)(inputs)
# Recurrent layers
heart = Bidirectional(CuDNNLSTM(256))(embedding)
dense = Dense((n_of_stocks * stock_size * 4), activation='relu')(heart)
# Fully connected layer
preoutput = []
outputs = []
for i in range(n_of_stocks):
    preoutput.append(Dense(stock_size * 4, activation='linear')(dense))
    outputs.append(Reshape((stock_size, 4))(preoutput[i]))
# Compile the model
model = Model(inputs=inputs, outputs=outputs, name="the_model")
model.summary()
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
model.save("mynetwork")
When I try to save the model, it fails with an error:
Traceback (most recent call last):
  File "D:\Projects\Project T\neural\network.py", line 95, in <module>
    model.save("mynetwork")
  File "C:\Users\nkart\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\nkart\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\layers\rnn\base_rnn.py", line 282, in _use_input_spec_as_call_signature
    if self.unroll:
AttributeError: 'CuDNNLSTM' object has no attribute 'unroll'
Am I doing something wrong? Should I try saving it as h5 instead?
Don't use CuDNNLSTM. Just use LSTM (which is newer) with default parameters; it will automatically use cuDNN, assuming you have cuDNN properly installed. CuDNNLSTM is for TensorFlow <= 2.0.
heart = Bidirectional(LSTM(256))(embedding)
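To make "default parameters" concrete: the tf.keras.layers.LSTM documentation lists the constructor arguments that must keep their default values for the cuDNN kernel to be selected. The sketch below encodes that list in a small checker. Note that cudnn_blockers and CUDNN_REQUIRED are hypothetical names made up for illustration, not part of Keras, and the checker does not cover the runtime conditions the documentation also mentions (eager execution in the outermost context, and masked inputs being strictly right-padded).

```python
# Hypothetical helper, NOT a Keras API: it only mirrors the constructor-argument
# conditions listed in the tf.keras.layers.LSTM documentation for cuDNN use.
CUDNN_REQUIRED = {
    "activation": "tanh",
    "recurrent_activation": "sigmoid",
    "recurrent_dropout": 0,
    "unroll": False,
    "use_bias": True,
}


def cudnn_blockers(**lstm_kwargs):
    """Return the LSTM arguments that would force the generic (non-cuDNN) kernel."""
    blockers = {}
    for name, required in CUDNN_REQUIRED.items():
        # An argument that is not passed keeps the Keras default,
        # which matches the required value.
        value = lstm_kwargs.get(name, required)
        if value != required:
            blockers[name] = value
    return blockers


if __name__ == "__main__":
    # LSTM(256) with defaults: nothing blocks the cuDNN kernel.
    print(cudnn_blockers())
    # LSTM(256, recurrent_dropout=0.2): falls back to the generic kernel.
    print(cudnn_blockers(recurrent_dropout=0.2))
```

An empty result means none of the arguments rules out the cuDNN path; whether it is actually used still depends on the GPU and the cuDNN installation.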
You might need to import from tensorflow.keras.layers instead of keras.layers.
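Putting both suggestions together, a minimal sketch might look like the following. It assumes TensorFlow 2.x with its bundled Keras 2 API (as in the traceback), where a path without an extension is saved in the SavedModel format; newer Keras releases may require a .keras or .h5 extension instead. The function name build_and_save, the layer sizes, and the single Dense output are placeholders for illustration and do not reproduce the model from the question.

```python
def build_and_save(save_path, max_sequence_len=50, vocab_size=1000):
    """Build a small stand-in model that uses LSTM and save it to save_path.

    All sizes are placeholder assumptions, not the values from the question.
    """
    # Import from tensorflow.keras rather than the standalone keras package.
    from tensorflow.keras.layers import (
        Bidirectional, Dense, Embedding, Input, LSTM,
    )
    from tensorflow.keras.models import Model

    inputs = Input(shape=(max_sequence_len,))
    embedding = Embedding(input_dim=vocab_size, output_dim=100)(inputs)

    # LSTM with default arguments replaces CuDNNLSTM.
    heart = Bidirectional(LSTM(256))(embedding)
    outputs = Dense(4, activation="linear")(heart)

    model = Model(inputs=inputs, outputs=outputs, name="the_model")
    model.compile(optimizer="adam", loss="mse")

    # With the Keras 2 API, an extension-less path produces a SavedModel directory.
    model.save(save_path)
    return model
```

Calling build_and_save("mynetwork") should then write the model without hitting the missing unroll attribute, since LSTM defines it.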