How to install TensorFlow 2.0 with Conda?

I get this error when I try to run a test example.

Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.

I have tried a recommended process for conda: create a new environment, install tensorflow-gpu, install Jupyter Notebook, and test some code. I have also tried changing the versions of cudatoolkit and cudnn, but I can't figure out which combination works. Installing tensorflow-gpu pulls in cudatoolkit 10.0.130 and cudnn 7.6.

import time

import tensorflow as tf

# Load MNIST and add a channel dimension for the Conv2D layers
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape(60000, 28, 28, 1)
test_images = test_images.reshape(10000, 28, 28, 1)

# Scale pixel values to the [0, 1] range
train_images, test_images = train_images / 255, test_images / 255

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Time the training run
start_time = time.time()
model.fit(train_images, train_labels, batch_size=128, epochs=15, verbose=1,
          validation_data=(test_images, test_labels))
print('Training took {} seconds'.format(time.time() - start_time))

For the benefit of the Stack Overflow community, I am posting the solution here even though it is already presented on GitHub.

Add the code below at the beginning of your program; it should resolve the issue.

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Let TensorFlow allocate GPU memory on demand instead of reserving it all at startup
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

For more details, please refer to this GitHub thread.
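As a side note, the same allow_growth setting can likely be expressed without the compat.v1 session, using TensorFlow 2.x's own configuration calls. The sketch below is not taken from the GitHub thread; it assumes `tf.config.experimental.list_physical_devices` and `tf.config.experimental.set_memory_growth` are available in your TensorFlow 2.x build. The helper name `enable_gpu_memory_growth` is made up for illustration, and the `ImportError` guard is only there so the snippet also runs on a machine without TensorFlow.

```python
def enable_gpu_memory_growth():
    """Ask TensorFlow to grow GPU memory on demand.

    Returns the number of GPUs the setting was applied to.
    (Hypothetical helper name, not part of TensorFlow.)
    """
    try:
        import tensorflow as tf
    except ImportError:
        # Guard only so this sketch runs where TensorFlow is not installed.
        return 0

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        # Must be called before the GPU is initialized; TensorFlow raises
        # RuntimeError if the runtime has already been set up.
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)


configured = enable_gpu_memory_growth()
print('Memory growth enabled on {} GPU(s)'.format(configured))
```

Like the session-based version, this has to run at the very beginning of the program, before the model is built or any tensor is placed on the GPU.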
