tensorflow feed list feature (multi-hot) to tf.estimator

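Some of my feature columns have the data type list, and the lists can differ in length. I want to encode such a column as a multi-hot categorical feature and feed it to tf.estimator. I tried the following, but it fails with the error Unable to get element as bytes. I believe this is common practice in deep learning, especially in recommender systems such as the Deep & Wide model. I found a related question here, but it does not show how to feed the data to an estimator.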

import pandas as pd
import tensorflow as tf

OUTDIR = "./data"

data = {"x": [["a", "c"], ["a", "b"], ["b", "c"]], "y": ["x", "y", "z"]}
df = pd.DataFrame(data)

Y = df["y"]
X = df.drop("y", axis=1)

indicator_features = [
    tf.feature_column.indicator_column(
        categorical_column=tf.feature_column.categorical_column_with_vocabulary_list(
            key="x", vocabulary_list=["a", "b", "c"]
        )
    )
]

model = tf.estimator.LinearClassifier(
    feature_columns=indicator_features, model_dir=OUTDIR
)

training_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X, y=Y, batch_size=64, shuffle=True, num_epochs=None
)

model.train(input_fn=training_input_fn)

This produces the following error:

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'testalg', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Error reported to Coordinator: , Unable to get element as bytes.
INFO:tensorflow:Saving checkpoints for 0 into testalg/model.ckpt.
---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1321     try:
-> 1322       return fn(*args)
   1323     except errors.OpError as e:

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1306       return self._call_tf_sessionrun(
-> 1307           options, feed_dict, fetch_list, target_list, run_metadata)
   1308

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1408           self._session, options, feed_dict, fetch_list, target_list,
-> 1409           run_metadata)
   1410     else:

InternalError: Unable to get element as bytes.

During handling of the above exception, another exception occurred:

InternalError                             Traceback (most recent call last)
 in ()
     44
     45
---> 46 model.train(input_fn=training_input_fn)

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
    364
    365     saving_listeners = _check_listeners_type(saving_listeners)
--> 366     loss = self._train_model(input_fn, hooks, saving_listeners)
    367     logging.info('Loss for final step: %s.', loss)
    368     return self

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py in _train_model(self, input_fn, hooks, saving_listeners)
   1117       return self._train_model_distributed(input_fn, hooks, saving_listeners)
   1118     else:
-> 1119       return self._train_model_default(input_fn, hooks, saving_listeners)
   1120
   1121   def _train_model_default(self, input_fn, hooks, saving_listeners):

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py in _train_model_default(self, input_fn, hooks, saving_listeners)
   1133       return self._train_with_estimator_spec(estimator_spec, worker_hooks,
   1134                                              hooks, global_step_tensor,
-> 1135                                              saving_listeners)
   1136
   1137   def _train_model_distributed(self, input_fn, hooks, saving_listeners):

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py in _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners)
   1334       loss = None
   1335       while not mon_sess.should_stop():
-> 1336         _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
   1337     return loss
   1338

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __exit__(self, exception_type, exception_value, traceback)
    687     if exception_type in [errors.OutOfRangeError, StopIteration]:
    688       exception_type = None
--> 689     self._close_internal(exception_type)
    690     # __exit__ should return True to suppress an exception.
    691     return exception_type is None

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in _close_internal(self, exception_type)
    724       if self._sess is None:
    725         raise RuntimeError('Session is already closed.')
--> 726       self._sess.close()
    727     finally:
    728       self._sess = None

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in close(self)
    972     if self._sess:
    973       try:
--> 974         self._sess.close()
    975       except _PREEMPTION_ERRORS:
    976         pass

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in close(self)
   1116       self._coord.join(
   1117           stop_grace_period_secs=self._stop_grace_period_secs,
-> 1118           ignore_live_threads=True)
   1119     finally:
   1120       try:

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py in join(self, threads, stop_grace_period_secs, ignore_live_threads)
    387       self._registered_threads = set()
    388       if self._exc_info_to_raise:
--> 389         six.reraise(*self._exc_info_to_raise)
    390       elif stragglers:
    391         if ignore_live_threads:

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    683             value = tp()
    684         if value.__traceback__ is not tb:
--> 685             raise value.with_traceback(tb)
    686         raise value
    687

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/inputs/queues/feeding_queue_runner.py in _run(self, sess, enqueue_op, feed_fn, coord)
     92         try:
     93           feed_dict = None if feed_fn is None else feed_fn()
---> 94           sess.run(enqueue_op, feed_dict=feed_dict)
     95         except (errors.OutOfRangeError, errors.CancelledError):
     96           # This exception indicates that a queue was closed.

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1133     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1134       results = self._do_run(handle, final_targets, final_fetches,
-> 1135                              feed_dict_tensor, options, run_metadata)
   1136     else:
   1137       results = []

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1314     if handle is None:
   1315       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1316                            run_metadata)
   1317     else:
   1318       return self._do_call(_prun_fn, handle, feeds, fetches)

/home/yinan.li1/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1333         except KeyError:
   1334           pass
-> 1335       raise type(e)(node_def, op, message)
   1336
   1337   def _extend_graph(self):

InternalError: Unable to get element as bytes.

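I think one problem in your case is that the column type in pandas is actually object rather than string. If you convert it into separate string columns, you will get rid of this error. Keep in mind that "The basic TensorFlow tf.string dtype allows you to build tensors of byte strings", so when the column stores objects (here, Python lists) instead of strings, you get this error.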
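A quick way to confirm this is to inspect the column before building the input function. The following is a minimal sketch that reuses the question's sample data; the variable names `x_dtype` and `element_types` are introduced here for illustration only. Note that older pandas versions also report plain string columns as object, so the element type is the more telling check.

```python
import pandas as pd

# Same sample data as in the question.
df = pd.DataFrame({"x": [["a", "c"], ["a", "b"], ["b", "c"]], "y": ["x", "y", "z"]})

# The dtype of a column holding Python lists is the generic object dtype.
x_dtype = df["x"].dtype

# Collect the Python types of the individual cells in the column.
element_types = {type(value).__name__ for value in df["x"]}

print(x_dtype)
print(element_types)
```

If the set of element types contains anything other than str, the column cannot be fed as a string tensor as-is.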

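The code below gets past the error you are seeing, but it does not fully solve your problem. The variable length of the lists still has to be handled, by padding or something similar, because indicator_column may have trouble with missing values.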

# Split the list column into one string column per position.
# This assumes every list has exactly two elements.
X2 = pd.DataFrame(X["x"].values.tolist(), columns=["x1", "x2"])

feat1 = tf.feature_column.categorical_column_with_vocabulary_list(
    key="x1", vocabulary_list=["a", "b", "c"]
)
feat2 = tf.feature_column.categorical_column_with_vocabulary_list(
    key="x2", vocabulary_list=["a", "b", "c"]
)
indicator_features = [
    tf.feature_column.indicator_column(categorical_column=feat1),
    tf.feature_column.indicator_column(categorical_column=feat2),
]

training_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X2, y=Y, batch_size=64, shuffle=True, num_epochs=None
)
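For lists of different lengths, two possible preprocessing steps are sketched below in plain Python. This is not part of the original answer: `pad_lists`, `multi_hot`, and the empty-string pad token are hypothetical names and choices made for illustration. The first helper pads every list to a common width so it can be split into fixed columns; the second computes the multi-hot vector directly, which is a different technique from using indicator_column.

```python
def pad_lists(rows, pad_token=""):
    """Right-pad every list to the length of the longest one.

    The pad token is an assumption; whether the feature column ignores it
    should be checked against the TensorFlow version in use.
    """
    width = max((len(row) for row in rows), default=0)
    return [list(row) + [pad_token] * (width - len(row)) for row in rows]


def multi_hot(rows, vocabulary):
    """Encode each list as a 0/1 vector over `vocabulary`.

    Tokens that are not in the vocabulary are skipped.
    """
    index = {token: i for i, token in enumerate(vocabulary)}
    encoded = []
    for row in rows:
        vector = [0] * len(vocabulary)
        for token in row:
            if token in index:
                vector[index[token]] = 1
        encoded.append(vector)
    return encoded


# Example with lists of length 2, 1 and 3.
rows = [["a", "c"], ["a"], ["b", "c", "a"]]
padded = pad_lists(rows)
vectors = multi_hot(rows, ["a", "b", "c"])
print(padded)
print(vectors)
```

The padded rows could be turned into a DataFrame with one column per position, in the same way X2 is built above, while the multi-hot vectors could instead be fed as plain numeric features.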
