簡體   English   中英

TensorFlow WALS 矩陣分解估計器的 export_savedmodel() 如何正確?

[英]How correctly export_savedmodel() for TensorFlow WALS Matrix Factorization estimator?

我在 TensorFlow 上使用這個WALS 矩陣分解模塊。 擬合估算器后,我嘗試使用 export_savedmodel() 方法保存模型,但無法提供正確的 services_input_fn 參數。 代碼在這里:

from tensorflow.contrib.factorization.python.ops import wals as wals_lib

# dense input array that shows the user X item interactions
# here set as dummy array
dense_array = np.ones((10,10))
num_rows, num_cols = dense_array.shape
emebedding_dim = 5 # manually setting hidden factor

factorizer = wals_lib.WALSMatrixFactorization(num_rows, num_cols, embedding_dim, max_sweeps=10)

# this generate_input_fn() is not shown here but it's a copy of
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/factorization/python/ops/wals_test.py#L82
input_fn, _, _ = generate_input_fn(np_matrix=dense_array, batch_size=32, mode=model_fn.ModeKeys.TRAIN)
factorizer.fit(input_fn, steps=10)

# MY PROBLEM IS HERE
# How to define the correct serving input function?
factorizer.export_savedmodel('path/to/save/model', serving_input_fn=???)

這里棘手的部分是,我相信 WALS 模塊使用的是 TensorFlow 的舊范式,其中serving_input_fn參數期望返回一個InputFnOps的可調用函數。 但是,更新程度更高的 Estimator,例如這個,需要一個返回tf.estimator.export.ServingInputReceivertf.estimator.export.TensorServingInputReceiver 我承認我還沒有完全精通 TensorFlow 的輸入函數,但是對於我保存 WALS 估算器的特定用例的任何幫助將不勝感激。 謝謝!

您對 Tensorflow 1.x 已棄用的部分是正確的。 對於 WALSMatrixFactorization,您的serving_input_fn需要返回一個InputFnOps對象。 因此,正確的輸入函數是:

def serving_input_receiver_fn():
    # some example input
    receiver_tensors = {'my_input': tf.placeholder(dtype=tf.string, shape=[None, 1], name='foo')}
    # some example feature that completely ignores the input
    features = {
        WALSMatrixFactorization.INPUT_ROWS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[3, 4]),
        WALSMatrixFactorization.INPUT_COLS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[3., 4.], dense_shape=[3, 4]),
        WALSMatrixFactorization.PROJECT_ROW: tf.constant(True),
    }
    return tf.contrib.learn.utils.InputFnOps(features, None, receiver_tensors)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM