
Tensorflow, how to store variables?

So I've just started experimenting with tensorflow, but I feel like I'm having a hard time grasping the concepts. I'm currently focusing on the MNIST dataset, using only 8000 images for training and 2000 for testing. The small code snippet I currently have is:

from keras.layers import Input, Dense, initializers
from keras.models import Model
from Dataset import Dataset
import matplotlib.pyplot as plt
from keras import optimizers, losses
import tensorflow as tf
import keras.backend as K

#global variables
d = Dataset()
num_features = d.X_train.shape[1]
low_dim = 32

def autoencoder():
    w = initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)
    input = Input(shape=(num_features,))

    encoded = Dense(low_dim, activation='relu', kernel_initializer = w)(input)

    decoded = Dense(num_features, activation='sigmoid', kernel_initializer = w)(encoded)

    autoencoder = Model(input, decoded)
    adagrad = optimizers.Adagrad(lr=0.01, epsilon=None, decay=0.0)
    autoencoder.compile(optimizer=adagrad, loss='binary_crossentropy')
    autoencoder.fit(d.X_train, d.X_train,
                    epochs=50,
                    batch_size=64,
                    shuffle=True,
                    )

    # autoencoder.predict already returns the reconstructed images
    decoded_imgs = autoencoder.predict(d.X_test)
    #sess = tf.InteractiveSession()
    #error = losses.mean_absolute_error(decoded_imgs[0], d.X_train[0])
    #print(error.eval())
    #print(decoded_imgs.shape)
    #sess.close()
    n = 20  # how many digits we will display
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # display original
        #sess = tf.InteractiveSession()
        error = losses.mean_absolute_error(decoded_imgs[i], d.X_test[i])
        #print(error.eval())
        #print(decoded_imgs.shape)
        #sess.close()
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(d.X_test[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(decoded_imgs[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    #print(error)
    plt.show()
    return error

What I want to do is store the errors as a list that I can later print or plot in a graph, but how do I do that efficiently with tensorflow/keras? Thanks in advance.

You can use the callback CSVLogger to store the errors in a csv file. Here is a code snippet for this task.

from keras.callbacks import CSVLogger

# define callbacks
callbacks = [CSVLogger(path_csv_logger, separator=';', append=True)]

# pass callbacks to model.fit() or model.fit_generator()
model.fit_generator(
    train_batch, train_steps, epochs=10, callbacks=callbacks,
    validation_data=validation_batch, validation_steps=val_steps)
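Once training finishes, the logged file can be read back for printing or plotting. A minimal sketch of parsing it with the standard `csv` module; the column names are assumptions (CSVLogger writes one column per tracked metric, typically `epoch` and `loss`), and a string stands in for the file here so the snippet is self-contained:

```python
import csv
import io

# Stand-in for the file CSVLogger writes with separator=';'
# (real code would use: open(path_csv_logger) instead of io.StringIO)
logged = "epoch;loss\n0;0.693\n1;0.401\n2;0.215\n"

losses = []
with io.StringIO(logged) as f:
    reader = csv.DictReader(f, delimiter=';')
    for row in reader:
        losses.append(float(row['loss']))

print(losses)  # per-epoch losses, ready for plt.plot(losses)
```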

Edit: to store the errors in a list, you can use something like this

# source https://keras.io/callbacks/
import keras

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
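During `fit()`, keras calls `on_train_begin` once and then `on_batch_end` after every batch, passing a `logs` dict that contains the current `'loss'`. A stand-alone sketch of those mechanics (the keras base class is dropped here so the snippet runs without keras installed; the simulated loss values are made up):

```python
# Same shape as the keras callback above, minus the base class
class LossHistory:
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

history = LossHistory()
history.on_train_begin()
# simulate what keras does during fit(): one call per batch
for batch, loss in enumerate([0.9, 0.5, 0.3]):
    history.on_batch_end(batch, logs={'loss': loss})

print(history.losses)  # [0.9, 0.5, 0.3]
```

With keras itself you would instead pass the callback to training, e.g. `history = LossHistory(); autoencoder.fit(d.X_train, d.X_train, callbacks=[history])`, after which `history.losses` holds one entry per batch.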
