簡體   English   中英

使用tensorflow估計器創建多輸入

[英]what use tensorflow estimator create multi-input

對不起,我的英語不好=。=

我創建了一個keras模型,並使用tf.keras.estimator.model_to_estimator轉換為estimator,但是該模型是多輸入的,我該如何創建數據集來提供數據?

這是我的型號代碼:

model = VGG19(include_top=False, input_shape=(182, 182 , 3))
y = model.output
y = keras.layers.Flatten()(y)
y = keras.layers.Dense(512, activation='relu')(y)
y = keras.layers.Dense(256, activation='relu')(y)
y = keras.layers.Dense(128, activation='relu')(y)
model = keras.Model(inputs=model.input, outputs=y)

input_image = keras.layers.Input(shape=(182, 182, 3))
input_anchor = keras.layers.Input(shape=(182, 182, 3))
out_image = model(input_image)
out_anchor = model(input_anchor)

out = keras.layers.concatenate([out_image, out_anchor])
out = keras.layers.Dense(1, activation='sigmoid')(out)
img_model = keras.Model([input_image, input_anchor], out)

face_model.compile(optimizer=tf.train.AdamOptimizer(1e-4, loss='binary_crossentropy', metrics=['accuracy'])

distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(num_gpus_per_worker=0)
config = tf.estimator.RunConfig(model_dir='/home/zjq/test/image_model.h5', train_distribute=distribution)

est_model = tf.keras.estimator.model_to_estimator(keras_model=image_model, config=config)

現在,我有一個輸入列表,形狀為[(100000,182,182,3),(100000,182,182,3),(100000,1)],如何定義輸入函數返回tf.data。數據集?

首先,輸入名稱占位符:

input_image = keras.layers.Input(shape=(182, 182, 3),name='image')
input_anchor = keras.layers.Input(shape=(182, 182, 3),name='anchor')

如果您的輸入數據是train_data並且形狀為[(100000, 182, 182, 3), (100000, 182, 182, 3), (100000, 1)] train_data [(100000, 182, 182, 3), (100000, 182, 182, 3), (100000, 1)] train_data [(100000, 182, 182, 3), (100000, 182, 182, 3), (100000, 1)] train_data [(100000, 182, 182, 3), (100000, 182, 182, 3), (100000, 1)] ,則執行以下操作:

BATCH_SIZE = 512
EPOCHS = 4

def input_fn(data, epochs, batch_size):
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices(({'image':data[0],'anchor':data[1]}, data[2]))
    # Shuffle, repeat, and batch the examples.
    SHUFFLE_SIZE = 1000
    dataset = dataset.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size)
    dataset = dataset.prefetch(2)
    # Return the dataset.
    return dataset
est_model.train(lambda :input_fn(train_data,EPOCHS,BATCH_SIZE))

可以根據您的需要調整參數BATCH_SIZEEPOCHSSHUFFLE_SIZE

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM