How to convert a scalar tensor to a scalar inside a model in TensorFlow?

I'm trying to add a data augmentation step to the TensorFlow MNIST example mnist_deep.py using tf.contrib.image.rotate():

rotate_angle = 0.1


def deepnn(x):
    ...
    with tf.name_scope('rotate'):
        angle = tf.placeholder(tf.float32)
        x_image = tf.contrib.image.rotate(x_image, angle)  # Wrong!
    ...
    return angle


...
angle = deepnn(x)
with tf.Session() as sess:
    angle.eval({angle: rotate_angle})

This does not work since tf.contrib.image.rotate() accepts only plain scalars as the angle.

I tried the approach from "TensorFlow: cast a float64 tensor to float32", but unfortunately the function suggested there also returns a tensor.

How can I convert the scalar tensor to a plain scalar inside the model itself? I would like to reuse the same model and feed different angles for training and testing.
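For reference, here is a minimal sketch of what the literal conversion involves, assuming TensorFlow graph mode written against the tf.compat.v1 API (so it runs on TensorFlow 1.15 as well as 2.x). In graph mode, tf.cast() returns another symbolic tensor, and a scalar tensor only gets a concrete value when it is evaluated with Session.run(); while the model is being built there is no number to extract. The names angle, angle32 and python_float are illustrative only.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # build a TF1-style graph instead of running eagerly

angle = tf1.placeholder(tf.float64, shape=())  # scalar tensor, no value yet
angle32 = tf.cast(angle, tf.float32)           # still a symbolic tf.Tensor

with tf1.Session() as sess:
    # The value only exists after the graph is run with a fed angle
    value = sess.run(angle32, feed_dict={angle: 0.1})  # a NumPy scalar
    python_float = float(value)                        # plain Python float

print(type(angle32).__name__, type(value).__name__, python_float)
```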

I don't think you need any special conversion, just some reorganization of the code: tf.contrib.image.rotate() also accepts a tensor, such as a placeholder, as its angle. Here is a possible solution; I hope it suits you:

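import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.0)
import numpy as np

rotate_angle = 0.1  # rotation angle in radians


def deepnn(x, angle):
    # `angle` can be a scalar tensor (e.g. a placeholder); no conversion needed
    x_image = tf.contrib.image.rotate(x, angle)
    return x_image


# Placeholders are created outside the model function so the caller can feed them
angle = tf.placeholder(tf.float32, shape=())
input_image_placeholder = tf.placeholder(tf.float32, shape=(100, 100, 3))

rotated_x_image = deepnn(input_image_placeholder, angle)

input_image = np.ones(dtype=float, shape=(100, 100, 3))

# The session is closed automatically when the `with` block exits
with tf.Session() as sess:
    curr_rotated_x_image = sess.run(
        rotated_x_image,
        feed_dict={angle: rotate_angle, input_image_placeholder: input_image})
    print(curr_rotated_x_image)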

I don't think declaring a placeholder inside the model function is a good idea, so I moved it outside, where the caller can create it once and feed it on every sess.run() call. Let me know if this solution works for you!
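The same pattern also covers the question's wish to reuse one model with different angles for training and testing. A minimal sketch, assuming the tf.compat.v1 API: tf.compat.v1.placeholder_with_default() makes the angle default to 0.0 (no rotation), so evaluation runs need not feed it, while training runs feed a nonzero angle. Because tf.contrib (and with it tf.contrib.image.rotate) was removed in TensorFlow 2.0, the hypothetical rotate_points() below stands in for the model by rotating 2-D points with a rotation matrix built from the angle tensor; the names and data are illustrative only.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF1-style graph mode


def rotate_points(points, angle):
    """Hypothetical stand-in for the model: rotate 2-D row vectors by `angle` radians.

    `angle` is a scalar tensor; tf.cos/tf.sin consume it directly.
    """
    c, s = tf.cos(angle), tf.sin(angle)
    rotation = tf.stack([tf.stack([c, -s]), tf.stack([s, c])])  # 2x2 matrix
    # Row vectors are rotated by multiplying with the transposed matrix
    return tf.matmul(points, rotation, transpose_b=True)


# Defaults to 0.0 (no augmentation) unless a value is fed
angle = tf1.placeholder_with_default(tf.constant(0.0), shape=())
points = tf1.placeholder(tf.float32, shape=(None, 2))
rotated = rotate_points(points, angle)  # one graph for training and testing

data = np.array([[1.0, 0.0]], dtype=np.float32)
with tf1.Session() as sess:
    # Training: feed an augmentation angle (90 degrees here)
    train_out = sess.run(rotated, feed_dict={points: data, angle: np.pi / 2})
    # Testing: omit the angle, so the default 0.0 is used
    test_out = sess.run(rotated, feed_dict={points: data})
```

With a real image model under TensorFlow 1.x, the same angle tensor would be passed to tf.contrib.image.rotate() instead of rotate_points(); the feeding pattern stays the same.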

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM