
How to use tf.data.Dataset and tf.keras to do multi-input and multi-output?

I ran into a problem with multiple outputs when building a model with tf.keras and using tf.data.Dataset as the input pipeline. Here is my code:

  import numpy as np
  import tensorflow as tf

  a = tf.keras.layers.Input(shape=(368, 368, 3))
  conv1 = tf.keras.layers.Conv2D(64, 3, 1)(a)
  conv2 = tf.keras.layers.Conv2D(64, 3, 1)(conv1)
  maxpool = tf.keras.layers.MaxPooling2D(pool_size=8, strides=8,
                                         padding='same')(conv2)
  conv3 = tf.keras.layers.Conv2D(5, 1, 1)(maxpool)  # first output: 46x46x5
  conv4 = tf.keras.layers.Conv2D(6, 1, 1)(maxpool)  # second output: 46x46x6

  inputs = a
  outputs = [conv3, conv4]

  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

  model.compile(optimizer=tf.keras.optimizers.SGD(),
                loss=tf.keras.losses.mean_squared_error)

  data = np.random.rand(10, 368, 368, 3)
  cpm  = np.random.rand(10, 46, 46, 5)
  paf  = np.random.rand(10, 46, 46, 6)

  # Inputs and targets are built as separate datasets and then zipped, so each
  # element is (input_batch, (cpm_batch, paf_batch)).
  dataset1 = tf.data.Dataset.from_tensor_slices(data)
  dataset2 = tf.data.Dataset.from_tensor_slices((cpm, paf))
  dataset1 = dataset1.batch(10).repeat()
  dataset2 = dataset2.batch(10).repeat()

  dataset  = tf.data.Dataset.zip((dataset1, dataset2))

  model.fit(dataset, epochs=200, steps_per_epoch=30)
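The 46×46 target shapes follow from the layer arithmetic: each 3×3 convolution with the default 'valid' padding and stride 1 trims 2 pixels (368 → 366 → 364), and MaxPooling2D with padding='same' and stride 8 gives ceil(364 / 8) = 46. A small plain-Python sanity check, assuming the standard Keras output-size formulas for 'valid' and 'same' padding (the helper names are hypothetical, not part of Keras):

```python
import math

def conv_valid_size(size, kernel, stride=1):
    # 'valid' padding: floor((size - kernel) / stride) + 1
    return (size - kernel) // stride + 1

def pool_same_size(size, stride):
    # 'same' padding: ceil(size / stride), independent of the pool size
    return math.ceil(size / stride)

def head_spatial_size(input_size):
    # Mirrors the model: Conv2D(3x3) -> Conv2D(3x3) -> MaxPooling2D(8, 8, 'same') -> Conv2D(1x1)
    s = conv_valid_size(input_size, 3)
    s = conv_valid_size(s, 3)
    s = pool_same_size(s, 8)
    s = conv_valid_size(s, 1)  # a 1x1 convolution keeps the spatial size
    return s

print(head_spatial_size(368))  # 46
```

So the shapes of cpm and paf match the two heads; the failure below is about how the targets are packaged, not their shapes.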

I'm using tensorflow==1.10.1 and I get the following error:

  File "/home/ulsee/work/tensorflow-HalfBodyPose/learnkeras.py", line 123, in <module>
    model.fit(dataset, epochs=200, steps_per_epoch=30)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/training.py", line 1278, in fit
    validation_split=validation_split)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/training.py", line 917, in _standardize_user_data
    exception_prefix='target')
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 143, in standardize_input_data
    data = [standardize_single_array(x) for x in data]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 143, in <listcomp>
    data = [standardize_single_array(x) for x in data]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 81, in standardize_single_array
    elif x.ndim == 1:
AttributeError: 'tuple' object has no attribute 'ndim'

Update: After upgrading to tf==1.11.0, this code works, so it appears to have been a bug in 1.10.1.
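For versions where this works (1.11+ according to the update, and TF 2.x), the usual pattern is a single dataset whose elements are (inputs, targets), with the targets given as a tuple in the same order as the model's outputs list. Below is a minimal sketch assuming TF 2.x; it reuses the question's architecture but scales the input down to 32×32 (so each head is 4×4) purely so it runs quickly. The sizes and the build_model helper are illustrative, not from the original post.

```python
import numpy as np
import tensorflow as tf

def build_model(size=32):
    # Same layout as the question's model, scaled down (assumption: 32x32 inputs).
    x = tf.keras.layers.Input(shape=(size, size, 3))
    h = tf.keras.layers.Conv2D(64, 3, 1)(x)
    h = tf.keras.layers.Conv2D(64, 3, 1)(h)
    h = tf.keras.layers.MaxPooling2D(pool_size=8, strides=8, padding="same")(h)
    cpm_out = tf.keras.layers.Conv2D(5, 1, 1)(h)
    paf_out = tf.keras.layers.Conv2D(6, 1, 1)(h)
    return tf.keras.Model(inputs=x, outputs=[cpm_out, paf_out])

model = build_model()
# A single loss is applied to every output; Keras sums the per-output losses.
model.compile(optimizer="sgd", loss="mse")

n = 8
data = np.random.rand(n, 32, 32, 3).astype("float32")
cpm = np.random.rand(n, 4, 4, 5).astype("float32")  # 32 -> 30 -> 28 -> ceil(28/8) = 4
paf = np.random.rand(n, 4, 4, 6).astype("float32")

# Each element is (input, (cpm_target, paf_target)); the target tuple must be
# in the same order as the model's outputs list.
dataset = tf.data.Dataset.from_tensor_slices((data, (cpm, paf))).batch(4).repeat()

history = model.fit(dataset, epochs=2, steps_per_epoch=2, verbose=0)
print(history.history["loss"])
```

A tuple (not a list) is used for the targets because tf.data treats a Python list as a single tensor to be stacked, whereas a tuple keeps the two targets as separate components.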

One thing you can try is to concatenate your outputs and then do the same for your target numpy arrays. I am not sure whether that makes sense for your application logic.
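Concretely, concatenating along the last (channel) axis turns the 5-channel cpm and 6-channel paf targets into one 11-channel array, and splitting at channel index 5 recovers them exactly. Because the outputs are concatenated in the same order, the same split can separate the model's predictions into the two heads again. A NumPy-only sketch (np.split is chosen here for the inverse step):

```python
import numpy as np

cpm = np.random.rand(10, 46, 46, 5)
paf = np.random.rand(10, 46, 46, 6)

# Concatenate targets along the channel axis, in the same order as the outputs.
label = np.concatenate((cpm, paf), axis=-1)
print(label.shape)  # (10, 46, 46, 11)

# Splitting at channel index 5 undoes the concatenation.
cpm_back, paf_back = np.split(label, [5], axis=-1)
print(cpm_back.shape, paf_back.shape)  # (10, 46, 46, 5) (10, 46, 46, 6)
```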

import numpy as np
import tensorflow as tf

K = tf.keras.backend             # Keras backend functions
Lambda = tf.keras.layers.Lambda

def conc_op(tensors):
    # Concatenate the two output maps along the channel (last) axis.
    return K.concatenate(tensors, axis=-1)

def conc_op_shape(input_shapes):
    # Same leading dimensions; channel counts are summed (5 + 6 = 11).
    shape1 = list(input_shapes[0])
    shape2 = list(input_shapes[1])
    return tuple(shape1[:-1] + [shape1[-1] + shape2[-1]])

a = tf.keras.layers.Input(shape=(368, 368, 3))
conv1 = tf.keras.layers.Conv2D(64, 3, 1)(a)
conv2 = tf.keras.layers.Conv2D(64, 3, 1)(conv1)
maxpool = tf.keras.layers.MaxPooling2D(pool_size=8, strides=8, padding='same')(conv2)
conv3 = tf.keras.layers.Conv2D(5, 1, 1)(maxpool)
conv4 = tf.keras.layers.Conv2D(6, 1, 1)(maxpool)

inputs = a
outputs = [conv3, conv4]
conc_outputs = Lambda(conc_op, output_shape=conc_op_shape)(outputs)  # single 46x46x11 output
model = tf.keras.models.Model(inputs=inputs, outputs=conc_outputs)

model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
model.summary()

data = np.random.rand(10, 368, 368, 3)
cpm  = np.random.rand(10, 46, 46, 5)
paf  = np.random.rand(10, 46, 46, 6)
label = np.concatenate((cpm, paf), axis=-1)  # targets concatenated the same way: 46x46x11

dataset = tf.data.Dataset.from_tensor_slices((data, label))
dataset = dataset.batch(2).repeat()
# TF 1.x API: fit() accepts an iterator. In TF 2.x, pass `dataset` directly instead.
model.fit(dataset.make_one_shot_iterator(), epochs=2, steps_per_epoch=5)  # just to check that it runs

Output:

Epoch 1/2
5/5 [==============================] - 15s 3s/step - loss: 0.4057
Epoch 2/2
5/5 [==============================] - 0s 32ms/step - loss: 0.2282
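One consequence to check before relying on the concatenation trick: mean_squared_error averages over the last axis, so the loss on the 11-channel output equals (5·mse_cpm + 6·mse_paf) / 11, whereas a two-output model compiled with one loss sums the per-output losses (mse_cpm + mse_paf with default loss weights). The two setups therefore differ only by fixed per-head weights, which could be matched with loss_weights=[5/11, 6/11] on the two-output model. A NumPy check of that identity, with random arrays standing in for targets and predictions and a hypothetical mse helper that mirrors Keras' per-pixel mean followed by averaging over all positions:

```python
import numpy as np

rng = np.random.default_rng(0)
cpm_true, cpm_pred = rng.random((4, 46, 46, 5)), rng.random((4, 46, 46, 5))
paf_true, paf_pred = rng.random((4, 46, 46, 6)), rng.random((4, 46, 46, 6))

def mse(y_true, y_pred):
    # Mean over the channel axis, then over batch and spatial positions.
    return np.mean(np.mean((y_true - y_pred) ** 2, axis=-1))

mse_cpm = mse(cpm_true, cpm_pred)
mse_paf = mse(paf_true, paf_pred)

# Loss on the single concatenated 11-channel output.
concat_loss = mse(np.concatenate((cpm_true, paf_true), axis=-1),
                  np.concatenate((cpm_pred, paf_pred), axis=-1))

print(np.isclose(concat_loss, (5 * mse_cpm + 6 * mse_paf) / 11))  # True
print(np.isclose(concat_loss, mse_cpm + mse_paf))                  # False
```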
