
How do I correctly call export_savedmodel() on the TensorFlow WALS matrix factorization estimator?

I am using the WALS matrix factorization module in TensorFlow (tf.contrib.factorization). After fitting the estimator, I'm trying to save the model with the export_savedmodel() method, but I can't work out the correct serving_input_fn argument. Here is the code:

import numpy as np
from tensorflow.contrib.factorization.python.ops import wals as wals_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn

# dense input array that shows the user x item interactions
# (set to a dummy array here)
dense_array = np.ones((10, 10))
num_rows, num_cols = dense_array.shape
embedding_dim = 5  # number of latent factors, set manually

factorizer = wals_lib.WALSMatrixFactorization(num_rows, num_cols, embedding_dim, max_sweeps=10)

# this generate_input_fn() is not shown here but it's a copy of
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/factorization/python/ops/wals_test.py#L82
input_fn, _, _ = generate_input_fn(np_matrix=dense_array, batch_size=32, mode=model_fn.ModeKeys.TRAIN)
factorizer.fit(input_fn, steps=10)

# MY PROBLEM IS HERE
# How to define the correct serving input function?
factorizer.export_savedmodel('path/to/save/model', serving_input_fn=???)
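
For reference, WALS consumes its ratings as sparse tensors rather than as a dense array. The sketch below uses NumPy only and shows how a dense interaction matrix maps onto the (indices, values, dense_shape) triple that tf.SparseTensor takes. dense_to_coo is a hypothetical helper, not part of the WALS module or the linked test file, and treating zeros as "unobserved" is an assumption about your data.

```python
import numpy as np

def dense_to_coo(dense):
    """Return (indices, values, dense_shape) for the nonzero entries of a 2-D array.

    Hypothetical helper: this is the triple tf.SparseTensor(indices=..., values=...,
    dense_shape=...) accepts, computed with NumPy only. Zeros are treated as unobserved.
    """
    dense = np.asarray(dense, dtype=np.float32)
    rows, cols = np.nonzero(dense)                             # observed positions, row-major order
    indices = np.stack([rows, cols], axis=1).astype(np.int64)  # shape [nnz, 2]
    values = dense[rows, cols]                                 # shape [nnz]
    dense_shape = np.array(dense.shape, dtype=np.int64)
    return indices, values, dense_shape

# A 3 x 4 interaction matrix with two observed entries.
m = np.zeros((3, 4))
m[0, 0] = 1.0
m[1, 2] = 2.0
indices, values, dense_shape = dense_to_coo(m)
print(indices.tolist(), values.tolist(), dense_shape.tolist())
# → [[0, 0], [1, 2]] [1.0, 2.0] [3, 4]
```

With the all-ones dummy matrix above, every one of the 100 cells counts as observed.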

The tricky part is that the WALS module is, I believe, built on an older TensorFlow paradigm (tf.contrib.learn), in which the serving_input_fn argument expects a callable that returns an InputFnOps. The newer tf.estimator Estimators, by contrast, expect a function that returns a tf.estimator.export.ServingInputReceiver or a tf.estimator.export.TensorServingInputReceiver. I admit I'm not yet fluent with TensorFlow input functions, so any help with saving my WALS estimator would be greatly appreciated. Thanks!
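
To make the difference concrete, here is a TensorFlow-free sketch that mirrors the field layout of the two return types with plain namedtuples, and adapts a receiver-style function to the older layout. The *Like classes, adapt, and the "my_feature" key are hypothetical stand-ins, not TensorFlow APIs; the mapping (no labels at serving time, receiver tensors passed as default_inputs) follows the InputFnOps(features, None, receiver_tensors) call used in the answer.

```python
from collections import namedtuple

# Hypothetical stand-ins that only mirror field names; they are NOT the TF classes.
# contrib.learn's InputFnOps has the fields (features, labels, default_inputs);
# a ServingInputReceiver carries features and receiver_tensors (only these two are modelled).
InputFnOpsLike = namedtuple("InputFnOpsLike", ["features", "labels", "default_inputs"])
ServingInputReceiverLike = namedtuple("ServingInputReceiverLike", ["features", "receiver_tensors"])

def to_input_fn_ops(receiver):
    """Map a receiver-style result onto the older InputFnOps-style layout."""
    # Serving has no labels; the tensors clients feed become default_inputs.
    return InputFnOpsLike(features=receiver.features,
                          labels=None,
                          default_inputs=receiver.receiver_tensors)

def adapt(new_style_fn):
    """Wrap a zero-argument, receiver-style serving function so it returns the old layout."""
    def old_style_fn():
        return to_input_fn_ops(new_style_fn())
    return old_style_fn

def new_style_serving_fn():
    # Strings stand in for tensors so the sketch runs without TensorFlow.
    return ServingInputReceiverLike(features={"my_feature": "<sparse rows>"},
                                    receiver_tensors={"my_input": "<placeholder>"})

ops = adapt(new_style_serving_fn)()
print(ops.labels, ops.default_inputs)  # → None {'my_input': '<placeholder>'}
```

The point is only the shape of the data: the older API wants a three-field result with an explicit (empty) labels slot.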

You're right about the deprecated parts of TensorFlow 1.x. WALSMatrixFactorization is a tf.contrib.learn estimator, so its serving_input_fn must return an InputFnOps object. A working serving input function would therefore be:

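import tensorflow as tf
from tensorflow.contrib.factorization.python.ops.wals import WALSMatrixFactorization
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils

def serving_input_receiver_fn():
    # example input: a placeholder that clients feed at serving time
    receiver_tensors = {'my_input': tf.placeholder(dtype=tf.string, shape=[None, 1], name='foo')}
    # example features: constants that completely ignore the input above
    features = {
        WALSMatrixFactorization.INPUT_ROWS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[3, 4]),
        WALSMatrixFactorization.INPUT_COLS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[3., 4.], dense_shape=[3, 4]),
        WALSMatrixFactorization.PROJECT_ROW: tf.constant(True),
    }
    # InputFnOps(features, labels, default_inputs); there are no labels at serving time
    return input_fn_utils.InputFnOps(features, None, receiver_tensors)

factorizer.export_savedmodel('path/to/save/model', serving_input_fn=serving_input_receiver_fn)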
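
Conceptually, with PROJECT_ROW set to True the exported model fits an embedding for each incoming row against the trained column factors. The NumPy sketch below illustrates that step in simplified form: it solves ridge regression over the observed entries only, whereas TensorFlow's WALS objective also weights unobserved entries, so its numbers will not match the exported model exactly. project_row, reg, and the random factors are hypothetical.

```python
import numpy as np

def project_row(col_factors, obs_cols, obs_vals, reg=0.1):
    """Fit an embedding x for one new row against fixed column factors V.

    Simplified (hypothetical) objective over observed entries only:
        minimize ||V_o x - r||^2 + reg * ||x||^2
    closed form: x = (V_o^T V_o + reg*I)^{-1} V_o^T r.
    WALS additionally weights unobserved entries; that term is omitted here.
    """
    V_o = col_factors[obs_cols]             # factors of the observed columns, [n_obs, k]
    r = np.asarray(obs_vals, dtype=float)   # observed values, [n_obs]
    k = col_factors.shape[1]
    A = V_o.T @ V_o + reg * np.eye(k)
    b = V_o.T @ r
    return np.linalg.solve(A, b)

rng = np.random.default_rng(0)
V = rng.normal(size=(10, 5))    # hypothetical trained column factors: 10 items, 5 dims
true_x = rng.normal(size=5)
cols = np.arange(10)
vals = V @ true_x               # a noise-free row that is consistent with V
x = project_row(V, cols, vals, reg=1e-8)
print(np.allclose(x, true_x, atol=1e-5))  # → True
```

With a noise-free row and negligible regularization the projection recovers the generating embedding; a larger reg shrinks the result toward zero.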
