
Ragged Tensors have no len() after conversion to Tensor

I am training a deep learning model on stacks of images with variable dimensions (shape = [Batch, None, 256, 256, 1]), where None (the number of images in a stack) can vary.

I use tf.RaggedTensor.merge_dims(0, 1) to convert the ragged tensor to a shape of [None, 256, 256, 1], so that it can be fed into a pretrained Keras CNN model.
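For reference, here is a minimal sketch of the reshaping that merge_dims(0, 1) performs on a [Batch, None, H, W, C] batch. It assumes tiny 4x4x1 images in place of 256x256x1. It also builds the batch directly with tf.RaggedTensor.from_row_lengths instead of the generator pipeline, so only the stack axis is ragged here.

```python
import tensorflow as tf

# Two stacks holding 3 and 2 images; 4x4x1 stands in for 256x256x1.
images = tf.random.normal((5, 4, 4, 1))
batch = tf.RaggedTensor.from_row_lengths(images, row_lengths=[3, 2])
print(batch.shape)  # (2, None, 4, 4, 1)

# Merge the batch axis (0) and the ragged stack axis (1) into a single image axis.
merged = batch.merge_dims(0, 1)
print(merged.shape)  # (5, 4, 4, 1)
```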

However, doing this inside the model with a Keras layer (the Lambda layer below) results in the following error: TypeError: the object of type 'RaggedTensor' has no len()

When I apply .merge_dims outside of the Keras layer and pass the resulting tensors to the same pretrained model, I do not get this error.

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, output_signature=(tf.RaggedTensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, ragged_rank=1)))
ds = ds.repeat().batch(8)
print(next(iter(ds)).shape)

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x
x = tf.keras.layers.Lambda(merge)(inputs)
merged_inputs = x
# x = ResNet50(x) # Uncommenting this will result in `model` producing an error when run for inference.

model = tf.keras.Model(inputs, x)

# Run inference
data = next(iter(ds))
model(data).shape # Will be an error if ResNet50 is used

Here is a Colab notebook that demonstrates the problem: https://colab.research.google.com/drive/1kN78mf4_oNqxWOluV054NlqmakC5msli?usp=sharing

I'm not sure whether the following workaround is robust for more complex network designs, but here are some pointers. The reason you got

Ragged Tensors have no len()

is the ResNet model: it expects a regular tensor, not a ragged_tensor. I'm not sure, however, whether ResNet(weights=None) can take a ragged_tensor directly. So if we convert the ragged data to a dense tensor right before it is fed to ResNet, it should not complain. Below is the full working code based on this idea. Note that a more efficient approach may well exist.
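One way to check this reasoning is to compare what merge_dims(0, 1) returns for different ragged_rank values. The sketch below assumes that the batched data in the question has ragged_rank=2, since synthetic_gen already yields ragged_rank=1 elements before batching. In that case, the merged result is still a tf.RaggedTensor, and to_tensor() is what turns it into the dense tensor ResNet expects. Shapes are shrunk to 4x4x1 for illustration.

```python
import tensorflow as tf

images = tf.random.normal((5, 4, 4, 1))  # 5 images of 4x4x1 (stand-in for 256x256x1)

# ragged_rank=1: only the stack axis is ragged (stacks of 3 and 2 images).
rt1 = tf.RaggedTensor.from_row_lengths(images, row_lengths=[3, 2])

# ragged_rank=2: each image is first wrapped with from_tensor(..., ragged_rank=1),
# as synthetic_gen does, and the images are then grouped into stacks.
per_image = tf.RaggedTensor.from_tensor(images, ragged_rank=1)
rt2 = tf.RaggedTensor.from_row_lengths(per_image, row_lengths=[3, 2])

merged1 = rt1.merge_dims(0, 1)
merged2 = rt2.merge_dims(0, 1)
print(rt1.ragged_rank, isinstance(merged1, tf.RaggedTensor))  # 1 False
print(rt2.ragged_rank, isinstance(merged2, tf.RaggedTensor))  # 2 True

# to_tensor() densifies the ragged dimension that merge_dims left behind.
dense = merged2.to_tensor()
print(dense.shape)  # (5, 4, 4, 1)
```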


Data

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, 
                                    output_signature=(tf.RaggedTensorSpec(
                                        shape=(None, 256, 256, 1), 
                                        dtype=tf.float32, ragged_rank=1
                                        )
                                    )
                                )
ds = ds.repeat().batch(8)
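To see where the extra ragged dimension comes from, the sketch below inspects element_spec before and after batch(). It uses a hypothetical tiny_gen with 4x4x1 images and a fixed stack size. My understanding (an assumption, not checked against every TF release) is that batching a RaggedTensorSpec increases its ragged_rank by one.

```python
import tensorflow as tf

def tiny_gen():
    # Hypothetical stand-in for synthetic_gen: one stack of 3 images of 4x4x1.
    image = tf.random.normal((3, 4, 4, 1))
    yield tf.RaggedTensor.from_tensor(image, ragged_rank=1)

ds_small = tf.data.Dataset.from_generator(
    tiny_gen,
    output_signature=tf.RaggedTensorSpec(
        shape=(None, 4, 4, 1), dtype=tf.float32, ragged_rank=1))

print(ds_small.element_spec.ragged_rank)           # 1
print(ds_small.batch(2).element_spec.ragged_rank)  # 2
```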

Basic Model

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x

Ragged Model

Here we convert the ragged_tensor to a dense tensor before passing the data to ResNet.

class RagModel(tf.keras.Model):
    def __init__(self):
        super(RagModel, self).__init__()
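        # merge the batch and stack axes: (8, None, 256, 256, 1) -> (None, 256, 256, 1)
        self.a = tf.keras.layers.Lambda(merge)
        # convert the (still ragged) result to a dense tensor: tensor = ragged_tensor.to_tensor()
        self.b = tf.keras.layers.Lambda(lambda x: x.to_tensor())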
        self.c = ResNet50
    
    def call(self, inputs, training=None, plot=False, **kwargs):
        x = self.a(inputs)
        x = self.b(x) if not plot else x  # to_tensor() is skipped only when building the plot graph
        x = self.c(x)
        return x
    
    # a helper function to plot 
    def build_graph(self):
        x = tf.keras.Input(type_spec=tf.RaggedTensorSpec(
            shape=(8, None, 256, 256, 1),
            dtype=tf.float32, ragged_rank=1)
        )
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x, plot=True))
   
x_model = RagModel()
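Since every image here is 256x256x1, the to_tensor() step only changes the type, not the data. If the inner rows genuinely had different lengths, to_tensor() would pad them. The sketch below shows that behavior on toy integer values (not image data).

```python
import tensorflow as tf

# Rows of different lengths: to_tensor() pads the short rows with default_value.
rt = tf.ragged.constant([[1, 2, 3], [4]])
print(rt.to_tensor().numpy().tolist())                  # [[1, 2, 3], [4, 0, 0]]
print(rt.to_tensor(default_value=-1).numpy().tolist())  # [[1, 2, 3], [4, -1, -1]]

# Rows of equal length (like the merged 256x256 images): nothing is padded.
uniform = tf.RaggedTensor.from_tensor(tf.ones((2, 3)), ragged_rank=1)
print(uniform.to_tensor().shape)  # (2, 3)
```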

Run

data = next(iter(ds)); print(data.shape)
x_model(data).shape 
(8, None, 256, 256, 1)
TensorShape([39, 1000])
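The leading 39 in the output is the total number of images across the 8 stacks in that batch, because merge_dims(0, 1) concatenates all stacks along one axis. Each stack holds between 1 and 11 images, since maxval=12 is exclusive. If per-stack predictions are needed, one option is to regroup the flat outputs using the batch's row lengths. This is a sketch that is not part of the original answer. Here stack_lengths and flat_logits are hypothetical stand-ins for data.row_lengths() and x_model(data).

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins: 3 stacks holding 2, 4 and 1 images, 10 classes.
stack_lengths = tf.constant([2, 4, 1], dtype=tf.int64)
total_images = int(tf.reduce_sum(stack_lengths).numpy())
flat_logits = tf.random.normal((total_images, 10))
print(flat_logits.shape)  # (7, 10): one row per image, like (39, 1000) above

# Undo merge_dims(0, 1): split the rows back into one group per stack.
per_stack = tf.RaggedTensor.from_row_lengths(flat_logits, row_lengths=stack_lengths)
print(per_stack.shape)                           # (3, None, 10)
print(per_stack.row_lengths().numpy().tolist())  # [2, 4, 1]
```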

Plot

tf.keras.utils.plot_model(x_model.build_graph(), 
              show_shapes=True, show_layer_names=True)


x_model.build_graph().summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(8, None, 256, 256, 1)]  0         
_________________________________________________________________
lambda_2 (Lambda)            (None, 256, 256, 1)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 1000)              25630440  
=================================================================
Total params: 25,630,440
Trainable params: 25,577,320
Non-trainable params: 53,120
_________________________________________________________________
