I am using the WALS matrix factorization module in TensorFlow. After fitting the estimator, I'm trying to save the model with the export_savedmodel() method, but I can't work out the correct serving_input_fn argument. The code is here:
import numpy as np
from tensorflow.contrib.factorization.python.ops import wals as wals_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
# dense input array that shows the user X item interactions
# here set as dummy array
dense_array = np.ones((10,10))
num_rows, num_cols = dense_array.shape
embedding_dim = 5 # manually setting hidden factor
factorizer = wals_lib.WALSMatrixFactorization(num_rows, num_cols, embedding_dim, max_sweeps=10)
# this generate_input_fn() is not shown here but it's a copy of
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/factorization/python/ops/wals_test.py#L82
input_fn, _, _ = generate_input_fn(np_matrix=dense_array, batch_size=32, mode=model_fn.ModeKeys.TRAIN)
factorizer.fit(input_fn, steps=10)
# MY PROBLEM IS HERE
# How to define the correct serving input function?
factorizer.export_savedmodel('path/to/save/model', serving_input_fn=???)
The tricky part here is that the WALS module is, I believe, using an older TensorFlow paradigm, in which the serving_input_fn argument expects a callable that returns an InputFnOps. Newer Estimators, however, expect a function that returns a tf.estimator.export.ServingInputReceiver or tf.estimator.export.TensorServingInputReceiver. I admit I'm not completely fluent in TensorFlow input functions yet, but any help with my specific use case of saving my WALS estimator will be greatly appreciated. Thanks!
You're right about the deprecated parts of TensorFlow 1.x. For WALSMatrixFactorization, your serving_input_fn needs to return an InputFnOps object. A correct input function would therefore be:
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops.wals import WALSMatrixFactorization

def serving_input_receiver_fn():
    # some example input
    receiver_tensors = {'my_input': tf.placeholder(dtype=tf.string, shape=[None, 1], name='foo')}
    # some example feature that completely ignores the input
    features = {
        WALSMatrixFactorization.INPUT_ROWS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[3, 4]),
        WALSMatrixFactorization.INPUT_COLS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[3., 4.], dense_shape=[3, 4]),
        WALSMatrixFactorization.PROJECT_ROW: tf.constant(True),
    }
    # labels are None because nothing is trained at serving time
    return tf.contrib.learn.utils.InputFnOps(features, None, receiver_tensors)
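To make the return contract easier to see without installing TensorFlow 1.x, here is a minimal sketch that uses a stand-in named tuple with the same three fields as InputFnOps (features, labels, default_inputs). The stand-in class, the function name toy_serving_input_fn, the dictionary keys, and the string values are all hypothetical placeholders for illustration. They are not real tensors and not the library's own code.

```python
import collections

# Stand-in for tf.contrib.learn.utils.InputFnOps (TF 1.x), which is a named
# tuple with these three fields. Redefined here so the sketch runs without TF.
InputFnOps = collections.namedtuple(
    "InputFnOps", ["features", "labels", "default_inputs"])


def toy_serving_input_fn():
    """Hypothetical serving input function; plain values stand in for tensors."""
    # What the exported model would accept at serving time (placeholders in TF).
    default_inputs = {"my_input": "stand-in for a string placeholder"}
    # What the model function consumes; these keys are illustrative only.
    features = {
        "input_rows": "stand-in for a SparseTensor of rows",
        "input_cols": "stand-in for a SparseTensor of columns",
        "project_row": True,
    }
    # Labels are None because nothing is being trained at serving time.
    return InputFnOps(features, None, default_inputs)


ops = toy_serving_input_fn()
print(ops._fields)
```

The point of the sketch is the shape of the result: a zero-argument callable that returns (features, labels, default_inputs), rather than the (features, receiver_tensors) pair that ServingInputReceiver carries. With the real function from the answer defined, the export call from the question would presumably become factorizer.export_savedmodel('path/to/save/model', serving_input_fn=serving_input_receiver_fn).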