简体   繁体   中英

TensorFlow estimator with custom model_fn runs forever in training phase

Relevant parts of the code below. The call to the Scattering object returns a 3D tensor of coefficients, based on fixed filter maps. The program only enters and returns from the Scattering call once, indicating that the code hangs forever somewhere in the first training step, but not in the Scattering call. Where could this be happening?

def my_model_fn(features, labels, mode, params):
    M, N = features.get_shape().as_list()[-2:]
    scattering_coefficients = Scattering(M=M, N=N, J=1, L=2)(features)
    batch_size = scattering_coefficients.get_shape().as_list()[0]
    # throw all coefficients into single vector for each image
    scattering_coefficients = tf.reshape(scattering_coefficients, [batch_size, -1])
    # returns tensor of correct shape
    print(scattering_coefficients)
    n_classes = 10
    n_coefficients = scattering_coefficients.get_shape().as_list()[1]

    # use linear classifier
    W = tf.Variable(tf.zeros([n_coefficients, n_classes]))
    b = tf.Variable(tf.zeros([n_classes]))
    y_predict = tf.nn.softmax(tf.matmul(scattering_coefficients, W) + b)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions={"predictions": y_predict})

    # loss function and training step
    cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_predict) )
    train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(cross_entropy)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=cross_entropy,
        train_op=train_op)


def sample_batch(X, y, batch_size):
    idx = np.random.choice(X.shape[0], batch_size, replace=False)
    return tf.convert_to_tensor(X[idx]), tf.convert_to_tensor(y[idx])

LEARNING_RATE = 0.01
BATCH_SIZE = 2
n_training_steps = 2
image_dimension = 28
model_params = {"learning_rate": LEARNING_RATE}

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

X_train = mnist.train.images.astype(np.float32)
X_train = normalize(X_train)
# number of channels is 1, -1 infers number of samples
X_train = X_train.reshape(-1, 1, image_dimension, image_dimension)
y_train = mnist.train.labels.astype(np.int64)

X_validation = mnist.validation.images.astype(np.float32)
X_validation = normalize(X_validation)
X_validation = X_validation.reshape(-1, 1, image_dimension, image_dimension)
y_validation = mnist.validation.labels.astype(np.int64)

train_input_fn = lambda: sample_batch(X_train, y_train, BATCH_SIZE)
validation_input_fn = lambda: sample_batch(X_validation, y_validation, BATCH_SIZE)

# Train
scattering_classifier = tf.estimator.Estimator(model_fn=my_model_fn, params=model_params)
# Hangs forever...
scattering_classifier.train(input_fn=train_input_fn, max_steps=n_training_steps)
# If I comment out training step, this finishes immediately.
print("start scoring accuracy")
predictions = scattering_classifier.predict(input_fn=validation_input_fn)

Changing

train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(cross_entropy)

to

train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(
    cross_entropy, global_step=tf.train.get_global_step())

solves the problem. Explanations are very welcome.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM