
How to run tf.keras fit with a tf.data.Dataset that has multiple outputs?

I am using TF 2.0 and want to train a network with tf.keras and tf.data.Dataset. However, I am having trouble using tf.keras fit with a tf.data.Dataset that has multiple outputs.

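I generate a random dataset (the variable database) to simulate the situation.

Each element of database has three tensors with shapes (32,32,3), (8) and (2).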
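To see how tf.data structures each element, here is a minimal check (assuming TF 2.x) that inspects element_spec for a dataset built the same way as in the question. Each element comes out as a flat 3-tuple (image, y1, y2), not as an (inputs, targets) pair.

```python
import tensorflow as tf

# Same random tensors as in the question (256 samples).
x = tf.random.normal((256, 32, 32, 3))
y1 = tf.random.normal((256, 8))
y2 = tf.random.normal((256, 2))

# Flat 3-tuple: every element is (image, y1, y2).
flat = tf.data.Dataset.from_tensor_slices((x, y1, y2))

# element_spec is a tuple of three TensorSpec objects, one per component.
for spec in flat.element_spec:
    print(spec.shape.as_list())
```

Nothing in this structure tells Keras that y1 and y2 are both targets; they are simply the second and third components.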

Here is my code:

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

x = tf.random.normal((256,32,32,3))
y1 = tf.random.normal((256,8))
y2 = tf.random.normal((256,2))

database = tf.data.Dataset.from_tensor_slices((x,y1,y2))
database = database.batch(32).repeat()

resnet50 = keras.applications.resnet.ResNet50(include_top=False,weights='imagenet',input_shape=(32,32,3))

#resnet50.summary()
resnet50.trainable = False

model = keras.Sequential([
    resnet50,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(8,activation='softmax',name='output1')
    ])

#model.summary()
IMGSZ = 32
input_image = keras.layers.Input((IMGSZ,IMGSZ,3),dtype=tf.float32)
x = model.layers[0](model.input)
x = keras.layers.GlobalAveragePooling2D()(x)
output = keras.layers.Dense(2,activation='softmax',name='output2')(x)
con = keras.layers.concatenate([model.output,output],axis=-1)
print(con.shape)
model2 = keras.models.Model(inputs=model.input,outputs=[model.output,output])
model2.summary()

base_learning_rate = 0.001
model2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
               loss=[keras.losses.categorical_crossentropy,keras.losses.binary_crossentropy],
               metrics=['accuracy'])

hsitory = model2.fit(database,epochs=1000,steps_per_epoch=200)

I want the model to compute two losses, one for output1 and one for output2.
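One way to give each output its own loss, as an alternative to relying on tuple order, is to key the model outputs, the targets and the losses by the same names (output1 and output2, matching the layer names in the question). This is a sketch assuming TF 2.x; the ImageNet ResNet50 backbone is replaced by a tiny hypothetical Conv2D stack so it runs without downloading weights.

```python
import tensorflow as tf
from tensorflow import keras

x = tf.random.normal((256, 32, 32, 3))
y1 = tf.random.normal((256, 8))
y2 = tf.random.normal((256, 2))

# Targets are a dict keyed by output name: (inputs, {name: target}).
ds = tf.data.Dataset.from_tensor_slices((x, {"output1": y1, "output2": y2}))
ds = ds.batch(32)

# Hypothetical small backbone standing in for ResNet50.
inputs = keras.Input((32, 32, 3))
h = keras.layers.Conv2D(8, 3, activation="relu")(inputs)
h = keras.layers.GlobalAveragePooling2D()(h)
out1 = keras.layers.Dense(8, activation="softmax", name="output1")(h)
out2 = keras.layers.Dense(2, activation="softmax", name="output2")(h)

# Dict outputs so the keys line up with the target and loss dicts.
model = keras.Model(inputs, {"output1": out1, "output2": out2})

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss={
        "output1": keras.losses.CategoricalCrossentropy(),
        "output2": keras.losses.BinaryCrossentropy(),
    },
)

history = model.fit(ds, epochs=1, verbose=0)
# The reported "loss" is the sum of the two per-output losses.
print(history.history["loss"])
```

Naming the outputs makes the target-to-loss mapping explicit, so reordering outputs later cannot silently pair a target with the wrong loss.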

This is the error I get:

Traceback (most recent call last):
  File "d:/Users/charl/Desktop/work/factory/5777/demo.py", line 39, in <module>
    hsitory = model2.fit(database,epochs=1000,steps_per_epoch=200)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py", line 697, in _initialize
    *args, **kwds))
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py", line 2855, in _get_concrete_function_internal_garbage_collected       
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:806 train_function  *
        return step_function(self, iterator)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:789 run_step  **
        outputs = model.train_step(data)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:749 train_step
        y, y_pred, sample_weight, regularization_losses=self.losses)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\compile_utils.py:204 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\losses.py:151 __call__
        losses, sample_weight, reduction=self._get_reduction())
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\utils\losses_utils.py:112 compute_weighted_loss
        losses, sample_weight)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\losses\util.py:143 scale_losses_by_sample_weight
        losses, None, sample_weight)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\losses\util.py:95 squeeze_or_expand_dimensions
        sample_weight = array_ops.squeeze(sample_weight, [-1])
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
        return target(*args, **kwargs)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\util\deprecation.py:507 new_func
        return func(*args, **kwargs)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\array_ops.py:4259 squeeze
        return gen_array_ops.squeeze(input, axis, name)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\gen_array_ops.py:10044 squeeze
        "Squeeze", input=input, squeeze_dims=axis, name=name)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\op_def_library.py:744 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\func_graph.py:593 _create_op_internal
        compute_device)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py:3485 _create_op_internal
        op_def=op_def)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py:1975 __init__
        control_input_ops, op_def)
    C:\Users\charl\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py:1815 _create_c_op
        raise ValueError(str(e))

    ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for '{{node categorical_crossentropy/weighted_loss/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](IteratorGetNext:2)' with input shapes: [?,2].
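The error names categorical_crossentropy/weighted_loss/Squeeze applied to IteratorGetNext:2, the third component of each batch, with shape [?,2]. When fit receives a dataset, Keras splits each element as (x,), (x, y) or (x, y, sample_weight), so with (x, y1, y2) the tensor y2 is taken as a sample weight and y1 alone as the targets. The pure-Python function below is a simplified, hypothetical re-implementation of that rule (not Keras's own code), meant only to make the mapping visible.

```python
def unpack_like_keras(element):
    """Hypothetical, simplified mirror of how Keras splits one dataset element.

    A tuple of length 1, 2 or 3 is read as (x,), (x, y) or (x, y, sample_weight);
    anything that is not a tuple is treated as x alone.
    """
    if isinstance(element, tuple):
        if len(element) == 1:
            return element[0], None, None
        if len(element) == 2:
            return element[0], element[1], None
        if len(element) == 3:
            return element[0], element[1], element[2]
        raise ValueError("expected a tuple of length 1, 2 or 3")
    return element, None, None


# Strings stand in for the tensors of one batch.
x, y, sw = unpack_like_keras(("image", "y1", "y2"))
print(y, sw)  # prints: y1 y2   (y2 lands in the sample_weight slot)

x, y, sw = unpack_like_keras(("image", ("y1", "y2")))
print(y, sw)  # prints: ('y1', 'y2') None   (both are targets)
```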

I have tried many of the approaches I found, and some others, but could not get it to work. Can anyone help me? Thanks a lot!

I solved the problem by changing

database = tf.data.Dataset.from_tensor_slices((x,y1,y2))

to

database = tf.data.Dataset.from_tensor_slices((x,(y1,y2)))
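With the nested structure each element becomes (image, (y1, y2)), an (x, y) pair whose y is a tuple matched positionally to outputs=[model.output, output], so the loss list [categorical_crossentropy, binary_crossentropy] is applied to output1 and output2 in that order. A minimal check, assuming TF 2.x:

```python
import tensorflow as tf

x = tf.random.normal((256, 32, 32, 3))
y1 = tf.random.normal((256, 8))
y2 = tf.random.normal((256, 2))

# Nested structure: (inputs, (target_for_output1, target_for_output2)).
nested = tf.data.Dataset.from_tensor_slices((x, (y1, y2))).batch(32).repeat()

# Unpack the spec with the same nesting as the data.
image_spec, (t1_spec, t2_spec) = nested.element_spec
print(image_spec.shape, t1_spec.shape, t2_spec.shape)
```

The batch dimension is reported as None because batch() defaults to drop_remainder=False, so TF cannot guarantee a fixed batch size statically.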

