
How to convert Tensor to string

I'm testing out tf.data, which is now the recommended way of feeding data in batches. However, I'm loading a custom dataset, so I need the file names as 'str' objects. When I create the dataset with tf.data.Dataset.from_tensor_slices, the file names are Tensor objects instead.

    def load_image(file, label):
        nifti = np.asarray(nibabel.load(file).get_fdata()) # <- here is the problem

        xs, ys, zs = np.where(nifti != 0)
        nifti = nifti[min(xs):max(xs) + 1, min(ys):max(ys) + 1, min(zs):max(zs) + 1]
        nifti = nifti[0:100, 0:100, 0:100]
        nifti = np.reshape(nifti, (100, 100, 100, 1))
        nifti = tf.convert_to_tensor(nifti, np.float32)
        return nifti, label


    def load_image_wrapper(file, labels):
        file = tf.py_function(load_image, [file, labels], (tf.string, tf.int32))
        return file


    dataset = tf.data.Dataset.from_tensor_slices((train, labels))
    dataset = dataset.map(load_image_wrapper, num_parallel_calls=6)
    dataset = dataset.batch(6)
    dataset = dataset.prefetch(buffer_size=6)
    iterator = iter(dataset)
    batch_of_images = iterator.get_next()

Here is the error: TypeError: expected str, bytes or os.PathLike object, not Tensor

I've tried using a 'py_function' wrapper, to no avail. Any ideas?

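Solved the problem in TensorFlow 2.1. Inside tf.py_function the file name arrives as an eager string tensor, so call .numpy() on it to get the bytes and decode them to a 'str' before passing the path to nibabel. The output types given to tf.py_function also have to match what load_image actually returns: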

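    import nibabel
    import numpy as np
    import tensorflow as tf


    def load_image(file, label):
        # Inside tf.py_function, `file` is an eager tf.string tensor:
        # .numpy() returns bytes, which are decoded to a str path.
        path = file.numpy().decode('utf-8')
        nifti = np.asarray(nibabel.load(path).get_fdata())

        # Crop to the bounding box of the non-zero voxels.
        xs, ys, zs = np.where(nifti != 0)
        nifti = nifti[min(xs):max(xs) + 1, min(ys):max(ys) + 1, min(zs):max(zs) + 1]
        # Keep the first 100 voxels along each axis, then add a channel axis.
        nifti = nifti[0:100, 0:100, 0:100]
        nifti = np.reshape(nifti, (100, 100, 100, 1))
        nifti = tf.convert_to_tensor(nifti, np.float64)
        return nifti, label


    def load_image_wrapper(file, labels):
        # The output dtypes must match what load_image actually returns.
        return tf.py_function(load_image, [file, labels], [tf.float64, tf.float64])


    # `train` (file names) and `labels` are defined elsewhere, as in the question.
    dataset = tf.data.Dataset.from_tensor_slices((train, labels))
    dataset = dataset.map(load_image_wrapper, num_parallel_calls=6)
    dataset = dataset.batch(2)
    dataset = dataset.prefetch(buffer_size=2)
    iterator = iter(dataset)
    batch_of_images = iterator.get_next()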
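
The conversion step can also be pulled into a small helper, sketched below so it runs without TensorFlow installed. The names tensor_to_str and FakeStringTensor are hypothetical and not part of any library. FakeStringTensor only imitates the one behaviour this relies on: calling .numpy() on a scalar eager tf.string tensor returns bytes.

```python
import os


def tensor_to_str(value, encoding="utf-8"):
    """Return a plain Python str for a file name passed into a py_function.

    Accepts anything exposing .numpy() (such as an eager string tensor),
    raw bytes, or an existing str / os.PathLike object.
    """
    # Eager tensors expose .numpy(); for a scalar string tensor it yields bytes.
    if hasattr(value, "numpy"):
        value = value.numpy()
    # Bytes must be decoded before they can be used as a str path.
    if isinstance(value, bytes):
        return value.decode(encoding)
    # A str is returned unchanged; os.PathLike objects are converted.
    return os.fspath(value)


class FakeStringTensor:
    """Hypothetical stand-in for an eager tf.string tensor (demo only)."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return self._data


print(tensor_to_str(FakeStringTensor(b"scan_001.nii.gz")))
print(tensor_to_str(b"scan_002.nii.gz"))
print(tensor_to_str("scan_003.nii.gz"))
```

With a helper like this, the first line of load_image would become nibabel.load(tensor_to_str(file)). It only works on values that are already eager, so it still has to be called from inside the function wrapped by tf.py_function, not directly inside dataset.map.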
