简体   繁体   English

TensorFlow:将张量拆分为“batch_size”切片

[英]TensorFlow: Split a tensor into `batch_size` slices

I have a rank-3 tensor named tensor of shape [batch_size, axis_1, axis_2] and want to split it into batch_size slices along the first axis like so:我有一个名为[batch_size, axis_1, axis_2]tensor的 3 级张tensor ,并希望将它沿第一个轴拆分为batch_size切片,如下所示:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)

Unfortunately, this doesn't work because the value of batch_size isn't yet known during construction of the graph.不幸的是,这是行不通的,因为在构建图的过程中还不知道batch_size的值。

How can I solve this?我该如何解决这个问题?

I get this error:我收到此错误:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.

Weirdly, trying to use batch_size in other TensorFlow functions seems to work:奇怪的是,尝试在其他 TensorFlow 函数中使用batch_size似乎有效:

tensor = tf.reshape(tensor, [batch_size, -1])

works fine despite the fact that the value of batch_size is unknown during graph construction.尽管在图构建期间batch_size的值未知,但工作正常。

Is the problem particularly with tf.split() ?问题是否特别与tf.split()

A work-around is to do:解决方法是:

batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
                        elems=tf.range(batch_size),
                        dtype=tf.float32)

I'm still interested in better solutions though.不过,我仍然对更好的解决方案感兴趣。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM