NotFoundError when using a TensorFlow optimizer on complex variables on a GPU

I'm trying to run TensorFlow on a GPU, performing SGD optimization on a complex variable. A minimal working example of the code I'm using is:

import tensorflow as tf
import numpy as np

# Check GPU is present
print(tf.config.list_physical_devices('GPU'))

# Initialise a complex matrix
mat = tf.random.uniform([1000, 1000], dtype=tf.float64)
mat = tf.complex(mat, mat)

var = tf.Variable(mat, trainable=True)

# Return the squared norm of this matrix as the loss function
def lossFn():
    return tf.math.abs(tf.linalg.trace(var @ tf.linalg.adjoint(var)))

# SGD optimizer
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

for _ in range(100):
    with tf.GradientTape() as tape:
        loss = lossFn()
    grads = tape.gradient(loss, [var])

    # This is the step that fails
    opt.apply_gradients(zip(grads, [var]))
    print(loss.numpy())

This works fine on a CPU, but on my tensorflow-gpu build it fails with the error:

tensorflow.python.framework.errors_impl.NotFoundError: No registered 'ResourceApplyGradientDescent' OpKernel for 'GPU' devices compatible with node {{node ResourceApplyGradientDescent}}
(OpKernel was found, but attributes didn't match) Requested Attributes: T=DT_COMPLEX128, use_locking=true

The error message is followed by a list of devices and their attributes.

If I comment out the line that makes mat complex (mat = tf.complex(mat, mat)), the code runs fine on the GPU, so the issue appears to be with how the GPU build handles complex numbers.

Has anybody run into a similar problem, or does anyone know of a fix?

TensorFlow build details:

  • System: Linux
  • Installation: using conda
  • TensorFlow version: 2.2.0
  • cuDNN: 7.6.5
  • CUDA: 10.1
  • Python: 3.8.5

I raised an issue on the TensorFlow GitHub and was told that complex types are not supported on GPU for now because Eigen lacks support for them, and that adding this support is not currently a priority.
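One possible workaround, which was not part of the GitHub reply and is only a sketch, is to store the real and imaginary parts as two separate float64 variables and build the complex matrix inside the loss function. The optimizer then only ever updates real-valued variables, so it requests ResourceApplyGradientDescent with T=DT_DOUBLE rather than T=DT_COMPLEX128. The names re, im and loss_fn below are hypothetical, the matrix is shrunk to 100 x 100 to keep the example quick, and I have not confirmed that every complex op in the forward pass (tf.complex, the matrix product, the trace) has a GPU kernel in a TF 2.2.0 build.

```python
import tensorflow as tf

# Assumption: keep the real and imaginary parts as two real-valued
# variables, so the optimizer never has to update a complex variable.
re = tf.Variable(tf.random.uniform([100, 100], dtype=tf.float64))
im = tf.Variable(tf.random.uniform([100, 100], dtype=tf.float64))

def loss_fn():
    # Rebuild the complex matrix from its parts on every call
    var = tf.complex(re, im)
    # Same loss as in the question: the squared norm of the matrix
    return tf.math.abs(tf.linalg.trace(var @ tf.linalg.adjoint(var)))

opt = tf.keras.optimizers.SGD(learning_rate=0.01)

losses = []
for _ in range(5):
    with tf.GradientTape() as tape:
        loss = loss_fn()
    # Gradients are taken with respect to the two real variables
    grads = tape.gradient(loss, [re, im])
    opt.apply_gradients(zip(grads, [re, im]))
    losses.append(float(loss.numpy()))

print(losses)
```

If the complex matrix is needed after training, it can be rebuilt with tf.complex(re, im). Whether this avoids every missing GPU kernel depends on the ops used in the loss, so it is worth checking on the actual GPU build.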
