
Tensorflow switch case in batch dimension with Conv2d

Suppose I have a batch of shape [n, h, w, c], a list of n indices in the range 0-9, and 10 Conv2D layers (convs) that I want to apply to the data depending on the index in the list. The index list changes with every batch.

E.g., with input x, batch size 4, and indices l = [1, 5, 1, 9], I would like to compute [convs[l[0]](x[0]), convs[l[1]](x[1]), convs[l[2]](x[2]), convs[l[3]](x[3])].
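The desired computation can be written directly as a per-sample loop (a minimal sketch; the layer sizes and filter count here are made up):

```python
import tensorflow as tf

n, h, w, c = 4, 8, 8, 3
convs = [tf.keras.layers.Conv2D(16, 3, padding="same") for _ in range(10)]
x = tf.random.normal([n, h, w, c])
l = [1, 5, 1, 9]

# Apply convs[l[i]] to sample i, then re-assemble the batch.
out = tf.concat([convs[idx](x[i : i + 1]) for i, idx in enumerate(l)], axis=0)
# out.shape == (4, 8, 8, 16)
```

This is correct but runs n separate small convolutions per batch, which is slow on a GPU; the question is about doing the same thing efficiently.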

A naive solution would be to compute every combination and gather the results based on l. However, this requires 10 times the memory. Is there a better solution to this problem?
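The naive compute-everything-and-gather approach can be sketched like this (sizes are made up):

```python
import tensorflow as tf

n, h, w, c = 4, 8, 8, 3
convs = [tf.keras.layers.Conv2D(16, 3, padding="same") for _ in range(10)]
x = tf.random.normal([n, h, w, c])
l = tf.constant([1, 5, 1, 9])

# Run every conv on the whole batch: [10, n, h, w, 16] -- 10x the activation memory.
all_out = tf.stack([conv(x) for conv in convs], axis=0)

# Select output l[i] for sample i: gather_nd with index pairs [[l[i], i], ...].
idx = tf.stack([l, tf.range(n)], axis=1)
out = tf.gather_nd(all_out, idx)  # [n, h, w, 16]
```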

One "hacky" solution would be to expand the input from [n, h, w, c] to [1, n, h, w, c] and then use Conv3D instead, with kernel shape [1, x, y], so that the batch dimension is treated as the depth dimension.
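To see why the reshape works: a Conv3D kernel with depth 1 slides independently over each depth slice, so on a [1, n, h, w, c] tensor it reproduces a plain Conv2D over the batch. A sketch with made-up sizes (note that the kernel here is still shared across all n slices; selecting a kernel per index is a separate step):

```python
import tensorflow as tf

n, h, w, cin, cout, k = 4, 8, 8, 3, 5, 3
x = tf.random.normal([n, h, w, cin])
kernel2d = tf.random.normal([k, k, cin, cout])

# Ordinary 2D convolution over the batch.
out2d = tf.nn.conv2d(x, kernel2d, strides=[1, 1, 1, 1], padding="SAME")

# The same computation phrased as a 3D convolution: fold the batch into the
# depth axis and give the kernel an extra leading depth of 1.
x3d = x[tf.newaxis]              # [1, n, h, w, cin]
kernel3d = kernel2d[tf.newaxis]  # [1, k, k, cin, cout]
out3d = tf.nn.conv3d(x3d, kernel3d, strides=[1, 1, 1, 1, 1], padding="SAME")
out3d = tf.squeeze(out3d, axis=0)  # back to [n, h, w, cout]
```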

If you have the weights defined separately (they can also be obtained via layer.weights), you could similarly stack them along the 0th dimension and use them through tf.nn.conv3d.
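The answer leaves the per-index application implicit, and a Conv3D kernel is shared across the depth axis, so stacking alone does not select a different kernel per sample. One way to finish the stacked-weights idea (an assumption on my part, swapping the conv3d call for a patch extraction plus einsum, which computes the same convolution) is to gather one kernel per sample from the stack and contract it against image patches:

```python
import tensorflow as tf

n, h, w, cin, cout, k = 4, 6, 6, 3, 5, 3
x = tf.random.normal([n, h, w, cin])
# Stacked weights of the 10 convolutions: [10, k, k, cin, cout].
stacked = tf.random.normal([10, k, k, cin, cout])
l = tf.constant([1, 5, 1, 9])

# One kernel per sample, selected by index: [n, k, k, cin, cout].
per_sample = tf.gather(stacked, l)

# Apply each sample's own kernel by contracting it against SAME-padded patches.
patches = tf.image.extract_patches(
    x, sizes=[1, k, k, 1], strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1], padding="SAME")            # [n, h, w, k*k*cin]
kernels = tf.reshape(per_sample, [n, k * k * cin, cout])
out = tf.einsum("nhwp,npo->nhwo", patches, kernels)  # [n, h, w, cout]
```

This uses O(1) extra weight memory per sample instead of computing all 10 outputs, at the cost of materializing the patch tensor.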

