简体   繁体   English

如何切片张量的一部分?

[英]How to slice a part of tensor?

I want to slice [3.0 ,33.0].I have tried to access this slice by following code.我想切片 [3.0 ,33.0]。我尝试通过以下代码访问此切片。 I'm not so clear about tf.slice command.我不太清楚 tf.slice 命令。 I'm not so clear about begin and size mentioned in documentaion about this command.我不太清楚关于此命令的文档中提到的开始和大小。 Can someone please make it easy to understand.有人可以让它容易理解。

batch = tf.constant([
  [#First image
    [[0.0,10.0],[1.0,11.0]],
    [[3.0,33.0],[4.0,44.0]]
  ],
  [#Second image
    [[5.0,55.0],[6.0,66.0]],
    [[7.0,77.0],[8.0,88.0]]
  ]
])
slice1 = tf.slice(batch,[0,0,0,0], [0,0,1,0]) 
sess = tf.InteractiveSEssion()
sess.run(tf.initialize_all_variables())
print slice1.eval()

I will explain your code with examples, so I created some cases, but first of all I'll explain you tf.slice(input, begin, size) parametters:我将用示例解释您的代码,因此我创建了一些案例,但首先我将向您解释tf.slice(input, begin, size)参数:

  • input is a ref to a Tensor. input是对张量的引用。
  • begin is the index from the slice begins. begin是切片开始的索引。
  • size is the offset of the slice. size是切片的偏移量。

So tf.slice works selecting from input a sub-Tensor that starts at begin index and end at begin + size , treating begin and size as index vectors.因此tf.slice工作从input选择一个子张量,该子张量从begin index begin并在begin + size结束,将beginsize视为索引向量。 The example below will clarify this:下面的例子将阐明这一点:

batch = tf.constant([
        [#First image
            [
                [0.0,10.0],
                [1.0,11.0]
            ],
            [
                [3.0,33.0],
                [4.0,44.0]
            ]
        ],
        [#Second image
            [
                [5.0,55.0],
                [6.0,66.0]
            ],
            [
                [7.0,77.0],
                [8.0,88.0]
            ]
        ]
    ])
slice1 = tf.slice(batch,[0,0,0,0], [1,1,1,1]) 
slice2 = tf.slice(batch,[0,1,0,0], [1,1,2,2]) 
slice3 = tf.slice(batch,[1,1,1,0], [1,1,1,2]) 
slice4 = tf.slice(batch,[0,0,0,0], [2,2,2,2]) 
sess = tf.InteractiveSession()
print("slice1: \n" + str(slice1.eval()) + "\n")
print("slice2: \n" + str(slice2.eval()) + "\n")
print("slice3: \n" + str(slice3.eval()) + "\n")
print("slice4: \n" + str(slice4.eval()) + "\n")

The outputs in this case are:这种情况下的输出是:

slice1: 
[[[[ 0.]]]]

slice2: 
[[[[  3.  33.]
   [  4.  44.]]]]

slice3: 
[[[[  8.  88.]]]]

slice4: 
[[[[  0.  10.]
   [  1.  11.]]

  [[  3.  33.]
   [  4.  44.]]]


 [[[  5.  55.]
   [  6.  66.]]

  [[  7.  77.]
   [  8.  88.]]]]
  • slice1 selects the first element of the Tensor because of it begins on [0,0,0,0] and picks only one element. slice1选择张量的第一个元素,因为它从[0,0,0,0]开始并且只选择一个元素。
  • slice2 selects the first element of the Tensor because of it begins on [0,1,0,0] and picks 1 element in the two first dimensions and 2 in three and four dimensions. slice2选择张量的第一个元素,因为它从[0,1,0,0]开始并在两个第一维中选择 1 个元素,在三个和四个维度中选择 2 个元素。
  • slice3 selects the first element of the Tensor because of it begins on [1,1,1,0] and picks only 1 element in the three first dimensions and 2 in the last. slice3选择 Tensor 的第一个元素,因为它从[1,1,1,0]开始,并且在三个第一个维度中只选择 1 个元素,在最后一个维度中选择 2 个元素。
  • slice4 selects all the element of the Tensor because of it begins on [0,0,0,0] and two elements by dimension, so it covers all the Tensor slice4选择了 Tensor 的所有元素,因为它开始于[0,0,0,0]和两个维度的元素,所以它覆盖了所有的 Tensor

Note that de number of dimensions are the same in all slides.请注意,所有幻灯片中的维度数都相同。 If you one to remove dimensions with only one element you can use tf.squeeze .如果您只删除一个元素的维度,您可以使用tf.squeeze

As how to slice a tensor is well explained above, I will show the trick about how to slice every element in the same position as [3.0, 33.0] in the tensor here (which is a similar problem that I met)由于上面已经很好地解释了如何切片张量,我将在这里展示如何在与[3.0, 33.0]相同位置的张量中切片每个元素的技巧(这是我遇到的类似问题)

batch = tf.constant([
  [#First image
    [[0.0,10.0],[1.0,11.0]],
    [[3.0,33.0],[4.0,44.0]]
  ],
  [#Second image
    [[5.0,55.0],[6.0,66.0]],
    [[7.0,77.0],[8.0,88.0]]
  ]
])
batch_shape = batch.shape
batch_sliced = tf.slice(batch,(0,1,0,0),(batch_shape[0],1,1,batch_shape[-1]))

Then you will get batch_sliced as然后你会得到batch_sliced作为

<tf.Tensor: shape=(2, 1, 1, 2), dtype=float32, numpy=
array([[[[ 3., 33.]]],
       [[[ 7., 77.]]]]```

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM