
Simple RNN Python Tensorflow error on model creation

I'm running sample code taken directly from one of Google's examples for creating an RNN, but I get an error when I run it. My setup is Visual Studio 2019 on Windows 10 x64, with an i7-10510U CPU and an MX230 GPU.

The Code:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
# Add an Embedding layer expecting input vocab of size 1000, and
# output embedding dimension of size 64.
model.add(layers.Embedding(input_dim=1000, output_dim=64))

# Add a SimpleRNN layer with 128 internal units.
model.add(layers.SimpleRNN(128))

# Add a Dense layer with 10 units.
model.add(layers.Dense(10))

model.summary()

The error on model.add(layers.SimpleRNN(128)):

Cannot convert a symbolic Tensor (simple_rnn/strided_slice:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

You can try upgrading TensorFlow to the latest version. I am able to run the code without any issues in TensorFlow 2.5.0, as shown below.

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
model.add(layers.Embedding(input_dim=1000, output_dim=64))
model.add(layers.SimpleRNN(128))
model.add(layers.Dense(10))

model.summary()

Output:

2.5.0
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 64)          64000     
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 128)               24704     
_________________________________________________________________
dense (Dense)                (None, 10)                1290      
=================================================================
Total params: 89,994
Trainable params: 89,994
Non-trainable params: 0
_________________________________________________________________
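As a sanity check on the summary above, the parameter counts can be reproduced by hand. The sketch below is plain Python and does not need TensorFlow; the helper names are hypothetical. It assumes the layers use their default settings (a bias term in SimpleRNN and Dense, none in Embedding).

```python
def embedding_params(vocab_size, embed_dim):
    # One embedding vector of length embed_dim per vocabulary entry.
    return vocab_size * embed_dim


def simple_rnn_params(input_dim, units):
    # Input kernel + recurrent kernel + bias vector.
    return input_dim * units + units * units + units


def dense_params(input_dim, units):
    # Weight matrix + bias vector.
    return input_dim * units + units


counts = {
    "embedding": embedding_params(1000, 64),
    "simple_rnn": simple_rnn_params(64, 128),
    "dense": dense_params(128, 10),
}
total = sum(counts.values())

for name, value in counts.items():
    print(f"{name}: {value}")
print(f"total: {total}")
```

The three values match the Param # column, and their sum matches the reported total, which suggests the model was built as intended.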

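This error is commonly reported when a TensorFlow release older than 2.5 is installed alongside NumPy 1.20 or newer, so it may be worth checking both versions before and after upgrading. The sketch below is only a rough heuristic built on that assumption, not an official compatibility check; the function names are hypothetical, and the package name "tensorflow" is assumed (a variant such as tensorflow-gpu would need a different name).

```python
import re
from importlib import metadata


def parse_version(text):
    """Return the leading (major, minor) numbers of a version string."""
    match = re.match(r"(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return int(match.group(1)), int(match.group(2))


def likely_affected(tf_version, numpy_version):
    """Heuristic (assumption): TensorFlow < 2.5 combined with NumPy >= 1.20."""
    return (parse_version(tf_version) < (2, 5)
            and parse_version(numpy_version) >= (1, 20))


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    tf_v = installed_version("tensorflow")
    np_v = installed_version("numpy")
    print("tensorflow:", tf_v, "numpy:", np_v)
    if tf_v and np_v:
        print("likely affected:", likely_affected(tf_v, np_v))
```

If the check reports a likely mismatch, upgrading TensorFlow as suggested above, or installing an older NumPy release, are the usual remedies.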