![](/img/trans.png)
[英]Parametrized generators while using tf.data.Dataset.from_generator()
[英]Loading data from generator using tf.data.Dataset.from_generator()
我想為我的度量學習加載數據 model,生成 function 的數據是get_data()
function
def get_data():
def my_generator():
for i in range(10):
anchor = list(np.expand_dims(cv2.imread('img1'), axis=0))
positive = list(np.expand_dims(cv2.imread('img2'), axis=0)
true = 0
a = (true, anchor, positive)
yield a
return tf.data.Dataset.from_generator(
my_generator,
output_types=(tf.int64, tf.Tensor, tf.Tensor),
output_shapes=(1, (1, 256, 256, 3), (1, 256, 256, 3))
)
dataset = get_data()
當我運行此代碼時,出現以下錯誤。 我嘗試將其他一些 arguments 傳遞給output_types
,例如 tf.float64 ,但它也不起作用。 我想我對形狀做錯了什么,但我不知道是什么。
TypeError:無法將值 <class 'tensorflow.python.framework.ops.Tensor'> 轉換為 TensorFlow DType。
任何幫助表示贊賞:)
正如我所想,問題出在形狀上,這對我有用
return tf.data.Dataset.from_generator(
my_generator,
output_types=(tf.float64, tf.float64, tf.float64),
output_shapes=(tf.TensorShape(None), tf.TensorShape((1, 256, 256, 3)),
tf.TensorShape((1, 256, 256, 3))))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.