
Determining tensor shapes at time of graph creation in TensorFlow

I'm trying to write a chunk of reusable code that reads the shape of one tensor and then uses the resulting object to define the shapes of other tensors. I can read either the dynamic shape of the tensor with tf.shape(tensor) or its static shape with tensor.get_shape(). The toy example looks like this (with the two different strategies):

def my_function_strategy_1(x, y):
    x_shape = tf.shape(x)
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

def my_function_strategy_2(x, y):
    x_shape = x.get_shape()
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

I want to use this chunk of code in different graphs. Sometimes the shape of the input tensors will be known and sometimes they will be unknown:

graph_A = tf.Graph()
with graph_A.as_default():
    x = tf.placeholder(tf.float32, [2, 4])
    y = tf.placeholder(tf.float32, [8])
    a, b, c, d = my_function(x, y)  # my_function is one of the two strategies above

graph_B = tf.Graph()
with graph_B.as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)

The behavior I want is: (A) when the shapes of the input tensors are known (as in graph_A), TensorFlow should calculate all of the shapes in the graph at graph creation time (so it can allocate resources efficiently, etc.), and (B) when the shapes of the input tensors are unknown (as in graph_B), TensorFlow should wait until runtime to calculate all of the shapes in the graph.

The strategy_1 version of the function almost does this. It achieves (B), but it doesn't quite achieve (A) because TensorFlow leaves the shape of some tensors unknown. For example, in the toy example above, the shapes of a, b, and c are calculated at graph creation time, but the shape of d is left unknown (even though d uses very similar operations). You can check this by printing a.get_shape(), b.get_shape(), etc.
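
For example, after building graph_A with my_function_strategy_1, the check looks like this (a sketch; exactly how the uninferred shape prints can vary between TensorFlow 1.x versions):

print(a.get_shape())  # (2, 4) -- inferred at graph creation time
print(b.get_shape())  # (2, 4)
print(c.get_shape())  # (2, 4)
print(d.get_shape())  # not fully inferred, e.g. (?, 4) or <unknown>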

Conversely, the strategy_2 version of the function achieves (A) for all tensors in the graph, but doesn't achieve (B) because TensorFlow (understandably) throws an exception when it tries to use the (unknown) static shape of the input tensor to shape other tensors.
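
A minimal reproduction of that failure (the exact error message may differ between versions):

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32)   # shape fully unknown
    y = tf.placeholder(tf.float32)
    x_shape = x.get_shape()          # TensorShape(None) -- unknown rank
    a = tf.reshape(y, x_shape)       # raises ValueError: the unknown
                                     # TensorShape cannot be converted
                                     # into a shape tensor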

Is there a way to achieve both (A) and (B) in a single function? How/why does the strategy_1 version work for most tensors in the graph, but not all?

You can pick out the elements of the shape that are statically known to get a "best of both worlds" result:

def my_get_shape(tensor):
    if tensor.shape.ndims is None:
        # Fully dynamic
        return tf.shape(tensor)
    if tensor.shape.is_fully_defined():
        # Fully static
        return tensor.shape
    # Partially static
    dyn_shape = tf.shape(tensor)
    shape = []
    for i, d in enumerate(tensor.shape):
        shape.append(d.value if d.value is not None else dyn_shape[i])
    return shape

def my_function(x, y):
    x_shape = my_get_shape(x)  # Or just tf.shape(x)! - see edit
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

# Fully static
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [2, 4])
    y = tf.placeholder(tf.float32, [8])
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: (2, 4) , b: (2, 4) , c: (2, 4) , d: (2, 4)

# Fully dynamic
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: <unknown> , b: <unknown> , c: (?, 4) , d: (?, 4)

# Partially static
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: (?, 4) , b: (?, 4) , c: (?, 4) , d: (?, 4)

EDIT:

Actually, replacing my_get_shape with tf.shape in the previous snippet works exactly the same. It seems that tf.shape should be the default choice (while being careful not to clutter the graph with redundant shape ops) unless you explicitly want to keep dimensions undefined.
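
A quick check of what "works exactly the same" means for the partially static case, using plain tf.shape (TF 1.x):

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.float32)
    a = tf.reshape(y, tf.shape(x))
print('a:', a.shape)
# a: (?, 4) -- the statically known dimension survives shape inference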

I have investigated a bit, but I couldn't work the whole thing out completely. I don't know if this is useful, but here are some things I found out. Apparently TensorFlow has a "shape inference" mechanism at the C++ level (it seems it used to be in Python, but not anymore). If you look, for example, in tensorflow/core/ops/array_ops.cc, you will see that every operation declaration includes a .SetShapeFn at the end, which is a function that uses InferenceContext to try to infer the output shape of the operation. This class can, among other things, check whether the values in a tensor are already known, which is true for example for tf.shape when the given tensor is static, or for tf.fill (and related ops like tf.ones) with known arguments. The result of the shape inference algorithm is what gets set as the tensor shape in Python, and it can be called directly (although I don't see how that would be useful) through call_cpp_shape_fn:

from tensorflow.python.framework.common_shapes import call_cpp_shape_fn
with tf.Graph().as_default():
    print(call_cpp_shape_fn(tf.reshape(tf.placeholder(tf.float32), tf.fill([2], 3)).op))
    # Shows this:
    # {
    #   'shapes': [dim { size: 3 } dim { size: 3 }],
    #   'handle_data': [None],
    #   'inputs_needed': b'\x12\x01\x01'
    # }
    print(call_cpp_shape_fn(tf.reshape(tf.placeholder(tf.float32), (2 * tf.fill([2], 3))).op))
    # Shows this:
    # {
    #   'shapes': [dim { size: -1 } dim { size: -1 }],
    #   'handle_data': [None],
    #   'inputs_needed': b'\x12\x01\x01'
    # }

You can see that, while tf.fill([2], 3) was correctly inspected, TensorFlow didn't work out that 2 * tf.fill([2], 3) is [6, 6], presumably because statically tracking operations like multiplication, even when the operands are known constants, was deemed too expensive.

What I haven't found out is where ops declare that their values can be statically known, or where/how exactly these values are retrieved. It seems that, for tf.shape at least, the mechanism is able to pick out the specific dimension values that are known and leave the rest undefined.
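
A sketch of that last point in the same call_cpp_shape_fn format as above (I would expect something along these lines, matching the (?, 4) shapes seen earlier, though I haven't verified the exact output dict):

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    r = tf.reshape(tf.placeholder(tf.float32), tf.shape(x))
    print(call_cpp_shape_fn(r.op))
    # Expected along the lines of:
    # {'shapes': [dim { size: -1 } dim { size: 4 }], ...}
    # i.e. the known dimension comes through, the unknown one stays -1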
