[英]TensorFlow: Split a tensor into `batch_size` slices
I have a rank-3 tensor named tensor
of shape [batch_size, axis_1, axis_2]
and want to split it into batch_size
slices along the first axis like so:我有一个名为[batch_size, axis_1, axis_2]
张tensor
的 3 级张tensor
,并希望将它沿第一个轴拆分为batch_size
切片,如下所示:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
Unfortunately, this doesn't work because the value of batch_size
isn't yet known during construction of the graph.不幸的是,这是行不通的,因为在构建图的过程中还不知道batch_size
的值。
How can I solve this?我该如何解决这个问题?
I get this error:我收到此错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
Weirdly, trying to use batch_size
in other TensorFlow functions seems to work:奇怪的是,尝试在其他 TensorFlow 函数中使用batch_size
似乎有效:
tensor = tf.reshape(tensor, [batch_size, -1])
works fine despite the fact that the value of batch_size
is unknown during graph construction.尽管在图构建期间batch_size
的值未知,但工作正常。
Is the problem particularly with tf.split()
?问题是否特别与tf.split()
?
A work-around is to do:解决方法是:
batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
elems=tf.range(batch_size),
dtype=tf.float32)
I'm still interested in better solutions though.不过,我仍然对更好的解决方案感兴趣。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.