
tensorflow, mini-batch, tf.placeholder - read state of nodes at given iteration

I want to print the value of the MSE at each epoch/batch combination. The code below prints the tensor object representing the MSE instead of its value at each iteration:

print("Epoch", epoch, "Batch_Index", batch_index, "MSE:", mse)

Example line of output:

Epoch 0 Batch_Index 0 MSE: Tensor("mse_2:0", shape=(), dtype=float32)

I understand this is because mse depends on tf.placeholder nodes, which by themselves hold no data. But once I run the code below:

sess.run(training_op, feed_dict={X: X_batch, y: y_batch})

the data should already be available, so I would expect the values of all nodes that depend on it to be accessible as well. Yet requesting an evaluation of the MSE in the print statement results in an error:

print("Epoch", epoch, "Batch_Index", batch_index, "MSE:", mse.eval())

Output2:

InvalidArgumentError: You must feed a value for placeholder tensor 'X_2' with dtype float and shape [?,9] ...

This tells me that mse.eval() does not see the data passed to sess.run().

Why does this happen? How should I change the code to make it report the MSE at each specified iteration?

import numpy as np
import tensorflow as tf
from sklearn.datasets import fetch_california_housing

housing = fetch_california_housing()
m, n = housing.data.shape
housing_data_plus_bias = np.c_[np.ones((m, 1)), housing.data] # ADD COLUMN OF 1s for BIAS!

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_housing_data = scaler.fit_transform(housing.data)
scaled_housing_data_plus_bias = np.c_[np.ones((m, 1)), scaled_housing_data]

X = tf.placeholder(tf.float32, shape=(None, n + 1), name="X")
y = tf.placeholder(tf.float32, shape=(None, 1), name="y")

theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0, seed=42), name="theta")

y_pred = tf.matmul(X, theta, name="predictions")
error = y_pred - y
mse = tf.reduce_mean(tf.square(error), name="mse")

n_epochs = 100
batch_size = 100
n_batches = int(np.ceil(m / batch_size))
learning_rate = 0.01  # must be defined before the optimizer that uses it

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)
init = tf.global_variables_initializer()

def fetch_batch(epoch, batch_index, batch_size):
    np.random.seed(epoch * n_batches + batch_index)  # not shown in the book
    indices = np.random.randint(m, size=batch_size)  # not shown
    X_batch = scaled_housing_data_plus_bias[indices] # not shown
    y_batch = housing.target.reshape(-1, 1)[indices] # not shown
    return X_batch, y_batch

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(n_epochs):
        for batch_index in range(n_batches):
            X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
            if (epoch % 50 == 0 and batch_index % 100 == 0):
                print("Epoch", epoch, "Batch_Index", batch_index, "MSE:", mse)
    best_theta = theta.eval()

best_theta

First, I think this kind of debugging and printing is much easier with eager execution enabled in TensorFlow.

Without eager execution, Python's print never shows the runtime value of a tensor; it only shows the tensor's symbolic description (name, shape, dtype), which is rarely what you want. Instead, use tf.Print to inspect the runtime value, for example tensor = tf.Print(tensor, [tensor]). The reassignment matters because tf.Print does not execute unless its output is used somewhere downstream.
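The tf.Print pattern can be sketched as follows. This is a minimal illustration, assuming the tf.compat.v1 API is available (tf.Print is deprecated in later releases in favor of tf.print); the function name run_with_print and the toy input are hypothetical, not part of the original code.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style API exposed under tf.compat.v1


def run_with_print():
    """Build a tiny graph and evaluate a tensor wrapped in tf.Print."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=(None,), name="x")
        mse = tf.reduce_mean(tf.square(x), name="mse")
        # tf.Print is an identity op with a printing side effect (to stderr).
        # Reassigning makes downstream consumers use the printing version.
        mse = tf1.Print(mse, [mse], message="MSE: ")

    with tf1.Session(graph=graph) as sess:
        data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
        # The placeholder still has to be fed in this run call.
        return float(sess.run(mse, feed_dict={x: data}))


value = run_with_print()
```

Note that the printing op only fires when something that depends on it is run, and the placeholders it depends on must be fed in that same call.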

I made it work by modifying the print statement as follows:

print("Epoch", epoch, "Batch_Index", batch_index, "MSE:", mse.eval(feed_dict={X: scaled_housing_data_plus_bias, y: housing.target.reshape(-1, 1)}))

Moreover, by feeding the complete data set (not just a batch), I was able to test how well the current batch-trained model generalizes to the whole sample. It should be easy to extend this to evaluate on test and hold-out samples as training progresses.

I am concerned that such on-the-fly evaluation (even on batches) could have an impact on performance. I will test this further.
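One way to limit that cost for batch-level logging is to fetch mse in the same sess.run call as training_op, so the batch loss comes out of the training step itself instead of a separate evaluation. A minimal sketch, assuming the tf.compat.v1 API is available and using synthetic data in place of the housing set (train_and_log and the generated data are hypothetical, for illustration only):

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style API exposed under tf.compat.v1


def train_and_log(n_steps=5, batch_size=100, n_features=9, learning_rate=0.01):
    """Run a few mini-batch steps and return the MSE fetched at each step."""
    rng = np.random.RandomState(42)
    # Synthetic stand-in for scaled_housing_data_plus_bias / housing.target.
    X_data = rng.randn(1000, n_features).astype(np.float32)
    y_data = X_data.dot(rng.randn(n_features, 1)).astype(np.float32)

    graph = tf.Graph()
    with graph.as_default():
        X = tf1.placeholder(tf.float32, shape=(None, n_features), name="X")
        y = tf1.placeholder(tf.float32, shape=(None, 1), name="y")
        theta = tf.Variable(
            tf1.random_uniform([n_features, 1], -1.0, 1.0, seed=42), name="theta")
        mse = tf.reduce_mean(tf.square(tf.matmul(X, theta) - y), name="mse")
        optimizer = tf1.train.GradientDescentOptimizer(learning_rate=learning_rate)
        training_op = optimizer.minimize(mse)
        init = tf1.global_variables_initializer()

    losses = []
    with tf1.Session(graph=graph) as sess:
        sess.run(init)
        for step in range(n_steps):
            idx = rng.randint(len(X_data), size=batch_size)
            # Fetch both in one call: a single feed_dict serves the update and the loss.
            _, mse_val = sess.run([training_op, mse],
                                  feed_dict={X: X_data[idx], y: y_data[idx]})
            print("Step", step, "MSE:", mse_val)
            losses.append(float(mse_val))
    return losses


losses = train_and_log()
```

The fetched value is the loss on that batch, typically computed with the parameters from before the update. Evaluating on the full data set still needs its own call with its own feed_dict, since fed values do not persist between run calls.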
