How to access TensorArray elements with tf.while_loop in TensorFlow?

I have a question about using TensorArray.

The Problem:
I would like to access the elements of a TensorArray inside a tf.while_loop. Please note that I am able to read the contents of the TensorArray directly using, for example, u1.read(0).

My current code:
Here is what I have so far:

import numpy as np
import tensorflow as tf

embeds_raw = tf.constant(np.array([
    [1, 1],
    [1, 1],
    [2, 2],
    [3, 3],
    [3, 3],
    [3, 3]
], dtype='float32'))
embeds = tf.Variable(initial_value=embeds_raw)
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable')
sen_len = tf.placeholder('int32', shape=[None], name='sen_len')
# max_l = tf.reduce_max(sen_len)
current_size = tf.shape(sen_len)[0]
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT')
added_container_variable = tf.add(container_variable, padded_sen_len)
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False)
u1 = u1.split(embeds, added_container_variable)

sentences = []
i = 0

def condition(_i, _t_array):
    return tf.less(_i, current_size)

def body(_i, _t_array):
    sentences.append(_t_array.read(_i))
    return _i + 1, _t_array

idx, arr = tf.while_loop(condition, body, [i, u1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
    print(sents)

The error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 267, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2584, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2673, in _as_graph_element_locked
    % (type(obj).__name__, types_str))
TypeError: Can not convert a TensorArray into a Tensor or Operation.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 191, in <module>
    main()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 187, in main
    variable_container()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 179, in variable_container
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 984, in _run
    self._graph, fetches, feed_dict_string, feed_handles=feed_handles)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 410, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 238, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 271, in __init__
    % (fetch, type(fetch), str(e)))
TypeError: Fetch argument has invalid type , must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)

I don't have enough reputation to comment, so I'll write an answer.

I don't quite understand what your code is intended to do, but the exception occurs because sess.run() can only fetch Tensors (or Operations), whereas arr is a TensorArray. You could do, for example:

sents = sess.run(arr.concat(), feed_dict={sen_len: [2, 1, 3]})

Of course, that just undoes your split. If you want to get all the values out, maybe:

sents = sess.run([arr.read(i) for i in range(512)], feed_dict={sen_len: [2, 1, 3]})

But I'm sure there must be a cleaner way than hardcoding 512. And presumably your while_loop is meant to be doing something.
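
One possibly cleaner route, offered as a sketch rather than a tested fix for your exact setup: size the TensorArray from sen_len instead of padding to 512, and have the loop body write its results into a second TensorArray that is carried as a loop variable. Appending to a Python list inside body will not collect one entry per iteration, because body is traced when the graph is built rather than executed on every step. The per-sentence reduce_sum, the names sentences_ta and sums_ta, and the tensorflow.compat.v1 import (used so that the TF 1.x graph API also runs on newer installs) are my assumptions, not something taken from your code.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: use the TF 1.x graph/session API

embeds = tf.constant(np.array(
    [[1, 1], [1, 1], [2, 2], [3, 3], [3, 3], [3, 3]], dtype='float32'))
sen_len = tf.placeholder(tf.int32, shape=[None], name='sen_len')
num_sentences = tf.shape(sen_len)[0]

# Size the input array from the fed lengths instead of hardcoding 512.
sentences_ta = tf.TensorArray(dtype=tf.float32, size=num_sentences,
                              clear_after_read=False, infer_shape=False)
sentences_ta = sentences_ta.split(embeds, sen_len)

# Output array carried through the loop; it replaces the Python list.
sums_ta = tf.TensorArray(dtype=tf.float32, size=num_sentences)

def condition(i, out_ta):
    return tf.less(i, num_sentences)

def body(i, out_ta):
    sentence = sentences_ta.read(i)  # rows belonging to sentence i
    # Hypothetical per-sentence work: sum the rows of this sentence.
    out_ta = out_ta.write(i, tf.reduce_sum(sentence, axis=0))
    return i + 1, out_ta

_, sums_ta = tf.while_loop(condition, body, [tf.constant(0), sums_ta])

# stack() turns the TensorArray into a Tensor that sess.run() can fetch.
sentence_sums = sums_ta.stack()

with tf.Session() as sess:
    result = sess.run(sentence_sums, feed_dict={sen_len: [2, 1, 3]})
    print(result)
```

With sen_len fed as [2, 1, 3], the rows are grouped as two, one, and three rows, so the per-sentence sums should come out as [2, 2], [2, 2] and [9, 9]. stack() works here because every written element has the same shape; the split sentences themselves have different lengths, so for those you would still need read() or concat().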
