繁体   English   中英

tensorflow:使用tf.estimator和keras进行词汇查询

[英]tensorflow: vocabulary lookup with tf.estimator and keras

我有以下数据集

username,itemname,value
"carl","socks",12.50
"john doe","shirts",30.00
...

我也有以下词汇查询文件

usernames.txt

carl
john doe
bob smith
...

itemnames.txt

socks
shirts
shoes
...

我将在预测时间接收字符串。 没有办法解决。 为了使训练相似,我正在使用tf.contrib.lookup

import tf.contrib.lookup

user_lookup = tf.contrib.lookup.index_table_from_file(
    vocabulary_file='usernames.txt'
)

item_lookup = tf.contrib.lookup.index_table_from_file(
    vocabulary_file='itemnames.txt'
)    

现在,我使用keras api定义了以下模型

import tensorflow as tf

user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32)
item_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32)

user_embedding = tf.keras.layers.Embedding(input_dim=num_users, output_dim=10)(user_input)
item_embedding = tf.keras.layers.Embedding(input_dim=num_items, output_dim=10)(item_input)

...
output = ...
model = tf.keras.Model([user_input, item_input], output)
model.compile(...)

我正在使用tf.estimator进行训练和预测。 因此,我的第一个直觉是执行以下操作:

my_estimator = tf.keras.estimator.model_to_estimator(keras_model=model)

tf.tables_initializer()

def train_fn(dataset_iterator):
     (username, itemname), value = dataset_iterator.get_next()
     userid = user_lookup.lookup(username)
     itemid = item_lookup.lookup(itemname)
     return (username, itemname), value

my_train_spec = tf.estimator.TrainSpec(
   input_fn=train_fn(train_data)
)

my_eval_spec = tf.estimator.EvalSpec(
   input_fn=train_fn(validation_data)
)

tf.estimator.train_and_evaluate(
    estimator=my_estimator,
    train_spec=my_train_spec,
    eval_spec=my_eval_spec
)

运行此命令时,出现以下错误:

ValueError: Tensor("Cast_2:0", shape=(), dtype=int32) must be from the same graph as Tensor("Item-Embedding-LMF/embeddings/Read/ReadVariableOp:0", shape=(429099, 10), dtype=float32, device=/job:ps/task:1).

谁能推荐解决此问题的方法? 也许甚至还有其他方法来处理此查找?

通常,查找没有问题。 所有变量都应使用相同的图进行关联。可以将模型写入范围内,例如

def model_fn: 
    with tf.variable_scope('my_model', reuse=tf.AUTO_REUSE):
    ....
    ..
     return estimator 

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM