简体   繁体   English

TensorFlow Keras 层中的重新排序轴

[英]Reorder axis in TensorFlow Keras layer

I am building a model that applies a random shuffle to data along the first non batch axis, applies a series of Conv1Ds, then applies the inverse of the shuffle.我正在构建一个模型,该模型沿第一个非批处理轴对数据应用随机洗牌,应用一系列 Conv1D,然后应用洗牌的逆。 Unfortunately the tf.gather layer messes up the batch dimension None , and i'm not sure why.不幸的是tf.gather层弄乱了批处理维度None ,我不知道为什么。

Below is an example of what happens.下面是发生的情况的示例。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dim = 90
input_img = keras.Input(shape=(dim, 4))

# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)

# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)

model = keras.models.Model(
   inputs=[input_img],
   outputs=tensor,
)

Here the summary is as follows:总结如下:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)           [(None, 90, 4)]           0         
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

Whereas I want the output shape of lambda_51 to be (None, 90, 4) .而我希望lambda_51的输出形状为(None, 90, 4)

Try to wrap input_img and order into a list when you pass them to tensor layer.当您将input_imgorder传递到tensor层时,尝试将它们包装到列表中。

In this way tensor layer becomes:这样tensor层就变成了:

tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])

and your summary:和你的总结:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 90, 4)]           0         
_________________________________________________________________
lambda_3 (Lambda)            (None, 90, 4)             0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM