How to compile and save the model in custom training TensorFlow
I tried to write a custom training loop following the TensorFlow tutorial. It gives output as follows:

```
Start of epoch 0
Training loss (for one batch) at step 0: 15.9249
Seen so far: 16 samples
Training loss (for one batch) at step 2: 14.9462
Seen so far: 48 samples
Training loss (for one batch) at step 4: 14.6554
Seen so far: 80 samples
Training loss (for one batch) at step 6: 14.1741
Seen so far: 112 samples
Training acc over epoch: 15.1999
Validation acc: 14.5266
Time taken: 8.02s
```
In the custom training loop, I don't know how to compile the model, or how to save the best model based on a criterion such as: if the loss on the validation set fails to decrease, or stays the same for 10 consecutive epochs, the model is saved to a model.h5 file and training stops. In addition, I would like to save the training loss and validation loss of each epoch to a CSV file, which is probably similar to what the following Keras commands do. I hope an expert can help me by adding a few lines of code that perform the above tasks. Thanks.
```python
save_model_name = 'model_name' + '.h5'
early_stopping = EarlyStopping(monitor='val_loss', patience=30, verbose=1)
model_checkpoint = ModelCheckpoint(save_model_name, monitor='val_R2_score',
                                   save_best_only=True, verbose=1, mode='max')
reduce_lr = ReduceLROnPlateau(factor=0.5, monitor='val_loss',
                              patience=15, min_lr=0.000001, verbose=1)
csv_logger = CSVLogger(model_name + ".csv", append=True)
```
My code is:
```python
import time
import numpy as np
import tensorflow as tf

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
batch_size = 16

# Dataset.
x_train = np.load('x_train_data.npy')
x_valid = np.load('x_valid_data.npy')
y_train = np.load('y_train_data.npy')
y_valid = np.load('y_valid_data.npy')

# Prepare the data for training.
x_train = np.expand_dims(x_train, axis=2)
x_valid = np.expand_dims(x_valid, axis=2)
y_train = np.expand_dims(y_train, axis=2)
y_valid = np.expand_dims(y_valid, axis=2)

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
val_dataset = val_dataset.batch(batch_size)

train_acc_metric = tf.keras.metrics.MeanSquaredError()
val_acc_metric = tf.keras.metrics.MeanSquaredError()

# Model.
model = test_model(im_width=1, im_height=80, neurons=16, kern_sz=20)
model.summary()

###### custom training loop ######
epochs = 2
losses = []
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)
        losses.append(float(loss_value))
        # Log every 2 batches.
        if step % 2 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))
    print(losses)
    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()
    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
```
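For reference, early stopping, best-model saving, and CSV logging can also be done by hand inside a custom loop. The sketch below is a minimal, self-contained illustration of the pattern: the tiny Dense model, the random arrays, and the file names are placeholders, not the question's `test_model` or .npy data, and one full-batch gradient step stands in for the inner batch loop.

```python
import csv

import numpy as np
import tensorflow as tf

# Toy stand-ins for the question's model and data (placeholders only).
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()

x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
x_valid = np.random.rand(16, 4).astype("float32")
y_valid = np.random.rand(16, 1).astype("float32")

patience = 10          # stop after 10 epochs without improvement
best_val_loss = np.inf
wait = 0
epochs = 30

with open("training_log.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "val_loss"])
    for epoch in range(epochs):
        # One gradient step per epoch keeps the sketch short;
        # a real loop would iterate over the batched dataset here.
        with tf.GradientTape() as tape:
            loss = loss_fn(y_train, model(x_train, training=True))
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        val_loss = float(loss_fn(y_valid, model(x_valid, training=False)))
        writer.writerow([epoch, float(loss), val_loss])

        # Save the best model so far and stop when val loss stalls.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            wait = 0
            model.save("best_model.h5")
        else:
            wait += 1
            if wait >= patience:
                print("Stopping early at epoch %d" % epoch)
                break
```

The same bookkeeping (best loss, patience counter, CSV row per epoch) drops into the epoch loop of the question's code unchanged.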
I suggest you take a look at this: https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit

As it says:

> When you need to customize what `fit()` does, you should override the training step function of the `Model` class. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual, and it will be running your own learning algorithm.

So you can implement a new class that subclasses `tf.keras.Model` to customize what happens during training: just add your `train_step()` and `test_step()` to that class. You can then use the `compile` and `fit` methods as usual, saving losses and reusing all the existing callbacks as needed.
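A minimal sketch of that pattern, assuming a toy Dense model and random data in place of the question's `test_model` and .npy arrays (the file names `best.weights.h5` and `history.csv` are also just examples):

```python
import numpy as np
import tensorflow as tf

class CustomModel(tf.keras.Model):
    """Subclass of tf.keras.Model with custom train/test steps."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            loss = self.loss_fn(y, self(x, training=True))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        x, y = data
        self.loss_tracker.update_state(self.loss_fn(y, self(x, training=False)))
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        # Metrics listed here are reset automatically between epochs.
        return [self.loss_tracker]

# Placeholder data.
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
x_valid = np.random.rand(16, 4).astype("float32")
y_valid = np.random.rand(16, 1).astype("float32")

model = CustomModel()
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3))
model.fit(
    x_train, y_train,
    validation_data=(x_valid, y_valid),
    epochs=3,
    batch_size=16,
    verbose=0,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10),
        tf.keras.callbacks.ModelCheckpoint(
            "best.weights.h5", monitor="val_loss",
            save_best_only=True, save_weights_only=True),
        tf.keras.callbacks.CSVLogger("history.csv"),
    ],
)
```

Because `fit()` now drives your own `train_step()`/`test_step()`, the stock `EarlyStopping`, `ModelCheckpoint`, and `CSVLogger` callbacks from the question work unmodified.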