繁体   English   中英

使用 tf.data.Dataset.from_generator() 从生成器加载数据

[英]Loading data from generator using tf.data.Dataset.from_generator()

我想为我的度量学习加载数据 model,生成 function 的数据是get_data() function

def get_data():
    def my_generator():
        for i in range(10):
            anchor = list(np.expand_dims(cv2.imread('img1'), axis=0))
            positive = list(np.expand_dims(cv2.imread('img2'), axis=0)
            true = 0
            a = (true, anchor, positive)
            yield a

    return tf.data.Dataset.from_generator(
        my_generator,
        output_types=(tf.int64, tf.Tensor, tf.Tensor),
        output_shapes=(1, (1, 256, 256, 3), (1, 256, 256, 3))
    )

dataset = get_data()

当我运行此代码时,出现以下错误。 我尝试将其他一些 arguments 传递给output_types ,例如 tf.float64 ,但它也不起作用。 我想我对形状做错了什么,但我不知道是什么。

TypeError:无法将值 <class 'tensorflow.python.framework.ops.Tensor'> 转换为 TensorFlow DType。

任何帮助表示赞赏:)

正如我所想,问题出在形状上,这对我有用

    return tf.data.Dataset.from_generator(
        my_generator,
        output_types=(tf.float64, tf.float64, tf.float64),
        output_shapes=(tf.TensorShape(None), tf.TensorShape((1, 256, 256, 3)), 
            tf.TensorShape((1, 256, 256, 3))))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM