
Tensorflow - How to use tf.gather() to deliver part of the first layer's outputs to one of the next layer's filters

Let's say I have this setup:

conv1 = tf.layers.conv2d(
    inputs=input_layer,
    filters=4,
    kernel_size=[14, 14],
    padding="valid",
    activation=tf.nn.relu
    )

conv2 = tf.layers.conv2d(
    inputs=conv1,
    filters=16,
    kernel_size=[5, 5],
    padding="valid",
    activation=tf.nn.relu
    )

Following the partial connection scheme in this paper, I want to feed a separate subset (and number) of conv1's output feature maps to each filter in conv2. Should I use tf.gather() for this, and if so, how?

tf.gather() slices along only one axis, so for your case tf.gather_nd() would work better. It could look like this:

import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.placeholder)

# make a placeholder for the indices of the outputs you will pick,
# or make it a constant if they won't change
indices = tf.placeholder(tf.int32, [None, 4])

conv1 = tf.layers.conv2d(
    inputs=input_layer,
    filters=4,
    kernel_size=[14, 14],
    padding="valid",
    activation=tf.nn.relu
)

# select the required outputs of conv1
new_input = tf.gather_nd(conv1, indices)
# or hard-wire them, if they're constant
new_input = tf.gather_nd(conv1, [[0, 0, 0, 0], [1, 0, 0, 0]])

# then reshape the result back to a proper 4-D size,
# because gathering single elements returns a flat list
# (unless you gather whole rows rather than single elements).
# The exact shape depends on what you selected and what conv2 needs, e.g.:
required_shape = [-1, 10, 10, 4]
new_input = tf.reshape(new_input, required_shape)
# or pass the new shape as a tensor instead of a constant list

conv2 = tf.layers.conv2d(
    inputs=new_input,
    filters=16,
    kernel_size=[5, 5],
    padding="valid",
    activation=tf.nn.relu
)
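If each group of conv2 filters should see whole feature maps of conv1, selecting along the channel axis alone is enough. One option is tf.gather with its axis argument, combined with one convolution per group. The sketch below assumes TensorFlow 2.x with eager execution and the NHWC layout. The connection table, the PartialConv2D class and the 24x24 input size are hypothetical, chosen for illustration only. They are not the paper's actual table.

```python
import tensorflow as tf

# Hypothetical connection table: for each group of conv2 filters, the indices
# of the conv1 output channels (feature maps) that group is allowed to see.
CONNECTIONS = [[0, 1], [1, 2], [2, 3], [0, 1, 2, 3]]
FILTERS_PER_GROUP = 4  # 4 groups x 4 filters = 16 filters, like conv2 above


class PartialConv2D(tf.keras.layers.Layer):
    """Hypothetical layer: one Conv2D per channel subset, outputs concatenated."""

    def __init__(self, connections, filters_per_group, kernel_size=5, **kwargs):
        super().__init__(**kwargs)
        # Store each subset as an int32 tensor so tf.gather can use it directly.
        self.channel_ids = [tf.constant(ids, dtype=tf.int32) for ids in connections]
        # Create the convolutions once so their weights are reused on every call.
        self.convs = [
            tf.keras.layers.Conv2D(filters_per_group, kernel_size,
                                   padding="valid", activation="relu")
            for _ in connections
        ]

    def call(self, x):
        outputs = []
        for ids, conv in zip(self.channel_ids, self.convs):
            # Pick whole feature maps along the channel axis (NHWC layout).
            subset = tf.gather(x, ids, axis=-1)
            outputs.append(conv(subset))
        # Join all groups back into one NHWC tensor.
        return tf.concat(outputs, axis=-1)


# Stand-in for conv1's output: a batch of 2 maps of size 24x24 with 4 channels.
conv1_out = tf.random.normal([2, 24, 24, 4])
layer = PartialConv2D(CONNECTIONS, FILTERS_PER_GROUP)
conv2_out = layer(conv1_out)
print(conv2_out.shape)  # (2, 20, 20, 16)
```

Each group gets its own convolution, so its filters only have weights for the channels listed in its subset. This is what a partial connection scheme requires. Masking the weights of a single full convolution would be another way to get the same effect, and is not shown here.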

With gather_nd, each index row gives explicit positions along the leading axes of the array, so you can pick single elements or whole sub-arrays. The official documentation has good examples:

indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]

indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]

indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']
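The following sketch shows how each index row is resolved. It reproduces the three documentation examples above in plain Python. The gather_nd function here is a hypothetical helper written only to illustrate the semantics, not TensorFlow's implementation. It ignores the batch_dims argument and does no bounds checking, so Python's negative indexing would silently apply where TensorFlow would raise an error.

```python
def gather_nd(params, indices):
    """Pure-Python sketch of tf.gather_nd semantics (without batch_dims).

    Each innermost index row walks down the leading axes of `params`; whatever
    remains at that position (a scalar or a sub-array) becomes one output entry.
    """
    # A flat list of ints is a single index row: follow it into params.
    if all(isinstance(i, int) for i in indices):
        result = params
        for i in indices:
            result = result[i]
        return result
    # Otherwise keep the outer structure of `indices` and recurse.
    return [gather_nd(params, row) for row in indices]


params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]

print(gather_nd(params, [[1]]))                   # [[['a1', 'b1'], ['c1', 'd1']]]
print(gather_nd(params, [[0, 1], [1, 0]]))        # [['c0', 'd0'], ['a1', 'b1']]
print(gather_nd(params, [[0, 0, 1], [1, 0, 1]]))  # ['b0', 'b1']
```

An index row with one entry per axis of params picks a single element. A shorter row picks a whole sub-array. This is why gathering single elements of conv1 in the code above gives a flat result that has to be reshaped.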
