I'm trying to create a collaborative filtering algorithm to suggest products to certain users.
I started recently and began working with TensorFlow (I found it sufficiently effective and flexible). I found this code that does what I need: it creates the model and trains it on user IDs, product IDs, and ratings: https://github.com/songgc/TF-recomm
I ran the code and trained the model.
After training, I need to make predictions, that is, get a list of suggestions for each user, so that they can be saved in a database that my Node.js application reads from.
How do I retrieve this list of suggestions for each user when the training is done?
if __name__ == '__main__':
    df_train, df_test = get_data()
    svd(df_train, df_test)
    print("Done!")
You can run
predict_result = sess.run(inter_op, feed_dict={user_batch: users, item_batch: items})
where inter_op is the model's inference op, and users and items are equal-length arrays that together list every (user, item) pair you want scored, e.g. each user ID paired with each item ID. predict_result then holds the predicted score for each pair, and you can store it in the DB.
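Because the placeholders are matched position by position, "all users and all items" has to be expanded into their cross product before feeding. Below is a minimal stdlib-only sketch, assuming dense integer IDs 0..n-1 and assuming predict_result comes back as a flat list in the same (user-major) order as the pairs. The helpers all_pairs and to_user_rows are hypothetical, and fake_predict only stands in for the sess.run call so the sketch runs without TensorFlow.

```python
from typing import Dict, List, Tuple


def all_pairs(num_users: int, num_items: int) -> Tuple[List[int], List[int]]:
    """Build parallel ID lists covering every (user, item) combination.

    Hypothetical helper: assumes IDs are dense integers 0..n-1.
    """
    users: List[int] = []
    items: List[int] = []
    for u in range(num_users):
        for i in range(num_items):
            users.append(u)
            items.append(i)
    return users, items


def to_user_rows(flat_scores: List[float], num_items: int) -> Dict[int, List[float]]:
    """Regroup a flat, user-major score list into one row of scores per user."""
    return {
        u: flat_scores[u * num_items:(u + 1) * num_items]
        for u in range(len(flat_scores) // num_items)
    }


if __name__ == "__main__":
    # Stand-in for sess.run(inter_op, feed_dict=...): a made-up scoring rule
    # used only so this sketch runs without TensorFlow.
    def fake_predict(users: List[int], items: List[int]) -> List[float]:
        return [float(u * 10 + i) for u, i in zip(users, items)]

    users, items = all_pairs(num_users=2, num_items=3)
    predict_result = fake_predict(users, items)
    rows = to_user_rows(predict_result, num_items=3)
    print(rows)  # {0: [0.0, 1.0, 2.0], 1: [10.0, 11.0, 12.0]}
```

The cross product grows as users × items, so for a large catalog you may want to feed it in chunks rather than in a single sess.run call.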
You need to modify the prediction part of your code to output the top K recommended products. The current code where the prediction is made is:
embd_user = tf.nn.embedding_lookup(w_user, user_batch, name="embedding_user")
embd_item = tf.nn.embedding_lookup(w_item, item_batch, name="embedding_item")
infer = tf.reduce_sum(tf.multiply(embd_user, embd_item), 1)
Here embd_user is the embedding of a particular user and embd_item is the embedding of a particular item. So instead of comparing a particular user with a particular item, you need to compare that user with all items. The matrix w_item holds the embeddings of all items, so this can be done by:
# Look up the embedding of the requested user(s): shape [batch, dim]
embd_user = tf.nn.embedding_lookup(w_user, user_batch, name="embedding_user")
# Multiply the user embedding, shape [batch, dim] ([1, dim] for one user),
# with the transposed item embeddings w_item, shape [item_num, dim],
# to get a score for every item, shape [batch, item_num]
predict = tf.matmul(embd_user, w_item, transpose_b=True)
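To see what tf.matmul(embd_user, w_item, transpose_b=True) computes, here is a stdlib-only sketch with made-up toy embeddings (not values from the model). Each item's score is the dot product of the user's embedding with that item's row of w_item, which is the same quantity the original tf.reduce_sum(tf.multiply(embd_user, embd_item), 1) computes for a single pair; the matmul just does it for every item at once. The function name score_all_items is hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def score_all_items(user_vec: Vector, item_matrix: Matrix) -> Vector:
    """Score one user against every item.

    Equivalent to matmul(user[1 x dim], item_matrix[item_num x dim]^T),
    flattened to a list of length item_num.
    """
    return [sum(u * v for u, v in zip(user_vec, item_row)) for item_row in item_matrix]


if __name__ == "__main__":
    # Toy values for illustration only (dim = 2, item_num = 3).
    w_item = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    embd_user = [2.0, 4.0]
    print(score_all_items(embd_user, w_item))  # [2.0, 4.0, 3.0]
```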
Then you can select the indices of the top k scores in the predicted output; those indices are the items to recommend to that user.
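Selecting the top k can be sketched with heapq.nlargest over (score, item index) pairs. This is a stdlib-only illustration with made-up scores; inside the TensorFlow graph, tf.nn.top_k(predict, k=...) returns the top values together with their indices. The function name top_k_items is hypothetical.

```python
import heapq
from typing import List


def top_k_items(scores: List[float], k: int) -> List[int]:
    """Return the indices (item IDs) of the k highest scores, best first."""
    return [i for _, i in heapq.nlargest(k, ((s, i) for i, s in enumerate(scores)))]


if __name__ == "__main__":
    scores = [2.0, 4.0, 3.0, 0.5]
    print(top_k_items(scores, k=2))  # [1, 2]
```

The resulting list of item IDs for each user is what you would write to the database for the Node.js application to read.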
-Gabriele Picco I'm having a problem with my TensorFlow recommendation system. Where can I contact you?