簡體   English   中英

使用 tf.data.Dataset.from_generator() 從生成器加載數據

[英]Loading data from generator using tf.data.Dataset.from_generator()

我想為我的度量學習加載數據 model,生成 function 的數據是get_data() function

def get_data():
    def my_generator():
        for i in range(10):
            anchor = list(np.expand_dims(cv2.imread('img1'), axis=0))
            positive = list(np.expand_dims(cv2.imread('img2'), axis=0)
            true = 0
            a = (true, anchor, positive)
            yield a

    return tf.data.Dataset.from_generator(
        my_generator,
        output_types=(tf.int64, tf.Tensor, tf.Tensor),
        output_shapes=(1, (1, 256, 256, 3), (1, 256, 256, 3))
    )

dataset = get_data()

當我運行此代碼時,出現以下錯誤。 我嘗試將其他一些 arguments 傳遞給output_types ,例如 tf.float64 ,但它也不起作用。 我想我對形狀做錯了什么,但我不知道是什么。

TypeError:無法將值 <class 'tensorflow.python.framework.ops.Tensor'> 轉換為 TensorFlow DType。

任何幫助表示贊賞:)

正如我所想,問題出在形狀上,這對我有用

    return tf.data.Dataset.from_generator(
        my_generator,
        output_types=(tf.float64, tf.float64, tf.float64),
        output_shapes=(tf.TensorShape(None), tf.TensorShape((1, 256, 256, 3)), 
            tf.TensorShape((1, 256, 256, 3))))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM