簡體   English   中英

切片張量列表-TensorFlow

[英]Slicing tensor with list - TensorFlow

有沒有辦法在Tensorflow中完成這種切片方法(使用numpy顯示示例)?

z = np.random.random((3,7,7,12))
x = z[...,[0,5]]

這樣

x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3)
assert np.all(x == x_hat)
x.shape # (3, 7, 7, 2)

在Tensorflow中,此操作

tfz = tf.constant(z)
i = np.array([0,5] dtype=np.int32)
tfx = tfz[...,i]

引發錯誤

ValueError: Shapes must be equal rank, but are 0 and 1
From merging shape 0 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [2].

您需要重塑形狀以使串聯結果與原始形狀(前三個尺寸)一致。

z = np.arange(36)
tfz = tf.reshape(tf.constant(z), [2, 3, 2, 3])
slice1 = tf.reshape(tfz[:,:,:,1], [2, 3, -1, 1])
slice2 = tf.reshape(tfz[:,:,:,2], [2, 3, -1, 1])
slice = tf.concat([slice1, slice2], axis=3)

with tf.Session() as sess:
  print sess.run([tfz, slice])


> [[[[ 0,  1,  2],
     [ 3,  4,  5]],

    [[ 6,  7,  8],
     [ 9, 10, 11]],

    [[12, 13, 14],
     [15, 16, 17]]],

   [[[18, 19, 20],
     [21, 22, 23]],

    [[24, 25, 26],
     [27, 28, 29]],

    [[30, 31, 32],
     [33, 34, 35]]]]

  # Get the last two columns
> [[[[ 1,  2],
     [ 4,  5]],

    [[ 7,  8],
     [10, 11]],

    [[13, 14],
     [16, 17]]],

   [[[19, 20],
     [22, 23]],

    [[25, 26],
     [28, 29]],

    [[31, 32],
     [34, 35]]]]

如Greeness所說,這是形狀錯誤。 不幸的是,似乎沒有一種像我希望的那樣簡單的方法,但這是我想出的通用解決方案:

def list_slice(tensor, indices, axis):
    """
    Args
    ----
    tensor (Tensor) : input tensor to slice
    indices ( [int] ) : list of indices of where to perform slices
    axis (int) : the axis to perform the slice on
    """

    slices = []   

    ## Set the shape of the output tensor. 
    # Set any unknown dimensions to -1, so that reshape can infer it correctly. 
    # Set the dimension in the slice direction to be 1, so that overall dimensions are preserved during the operation
    shape = tensor.get_shape().as_list()
    shape[shape==None] = -1
    shape[axis] = 1

    nd = len(shape)

    for i in indices:   
        _slice = [slice(None)]*nd
        _slice[axis] = slice(i,i+1)
        slices.append(tf.reshape(tensor[_slice], shape))

    return tf.concat(slices, axis=axis)



z = np.random.random(size=(3, 7, 7, 12))
x = z[...,[0,5]]
tfz = tf.constant(z)
tfx_hat = list_slice(tfz, [0, 5], axis=3)
x_hat = tfx_hat.eval()

assert np.all(x == x_hat)

怎么樣:

x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1) 

這對我有用:

z = np.random.random((3,7,7,12))
tfz = tf.constant(z)
x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1)

x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3)

with tf.Session() as sess:
    x_run = sess.run(x)

assert np.all(x_run == x_hat)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM