TensorFlow: create a dataset using a generator for multiple columns with different data types
I know that I can use tensorflow.data.TextLineDataset for this, but I'd like to write a custom function that creates a Dataset from a generator.
I'm implementing the input function for the census income data like this:
import pandas as pd
import tensorflow as tf

_CSV_COLUMNS = [
    ('age', tf.int32),
    ('workclass', tf.string),
    ('fnlwgt', tf.int32),
    ('education', tf.string),
    ('education_num', tf.int32),
    ('marital_status', tf.string),
    ('occupation', tf.string),
    ('relationship', tf.string),
    ('race', tf.string),
    ('gender', tf.string),
    ('capital_gain', tf.int32),
    ('capital_loss', tf.int32),
    ('hours_per_week', tf.int32),
    ('native_country', tf.string),
    ('income_bracket', tf.string),
]

def input_csv(data_file, num_epochs, batch_size):
    df = pd.read_csv(data_file, header=None)

    def gen():
        for _, row in df.iterrows():
            # Columns 0-13 are the features; column 14 is the income label.
            yield dict(zip([n[0] for n in _CSV_COLUMNS[:14]], row[:14])), row[14] == '>50K'

    # Note: num_epochs and batch_size are never used, so every element of
    # this dataset is a single, unbatched row.
    return tf.data.Dataset.from_generator(gen, (dict(_CSV_COLUMNS[:14]), tf.bool))
When I try this function with the Estimator API, it results in this error:
InvalidArgumentError (see above for traceback): assertion failed: [Feature (key: age) cannot have rank 0. Given: Tensor(\"IteratorGetNext:0\", dtype=int32)] [Condition x > 0 did not hold element-wise:] [x (linear/linear_model_1/linear_model/age/Rank:0) = ] [0]
Any ideas? Thanks in advance.
Additional info:
I'm testing it with SageMaker local mode. The train_input_fn and model_fn look like this:
_NUMERIC_COLUMNS = [
    tf.feature_column.numeric_column(c) for c in
    ['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
]

def model_fn(features, labels, mode, hyperparameters):
    classifier = tf.estimator.LinearClassifier(_NUMERIC_COLUMNS)
    return classifier.model_fn(features, labels, mode, None)

def train_input_fn(training_dir, hyperparameters):
    return input_csv(os.path.join(training_dir, 'adult.data.csv'), 3, 20)
The traceback is as follows (the frame from my own source is the one in /opt/ml/code/train.py):
Caused by op 'linear/linear_model_1/linear_model/age/assert_positive/assert_less/Assert/Assert', defined at:
File "/usr/local/bin/entry.py", line 28, in <module>
modes[mode]()
File "/usr/local/lib/python3.6/site-packages/container_support/training.py", line 36, in start
fw.train()
File "/usr/local/lib/python3.6/site-packages/tf_container/train_entry_point.py", line 164, in train
train_wrapper.train()
File "/usr/local/lib/python3.6/site-packages/tf_container/trainer.py", line 73, in train
tf.estimator.train_and_evaluate(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 451, in train_and_evaluate
return executor.run()
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 617, in run
getattr(self, task_to_run)()
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 654, in run_master
self._start_distributed_training(saving_listeners=saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 767, in _start_distributed_training
saving_listeners=saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 376, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1145, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1170, in _train_model_default
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1133, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tf_container/trainer.py", line 108, in _model_fn
return self.customer_script.model_fn(features, labels, mode, params)
File "/opt/ml/code/train.py", line 32, in model_fn
return classifier.model_fn(features, labels, mode, None)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 263, in public_model_fn
return self._call_model_fn(features, labels, mode, config)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1133, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/canned/linear.py", line 339, in _model_fn
sparse_combiner=sparse_combiner)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/canned/linear.py", line 163, in _linear_model_fn
logits = logit_fn(features=features)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/canned/linear.py", line 101, in linear_logit_fn
cols_to_vars=cols_to_vars)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 464, in linear_model
retval = linear_model_layer(features) # pylint: disable=not-callable
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 736, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 647, in call
weighted_sum = layer(builder)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 362, in __call__
outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 736, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 539, in call
weight_var=self._weight_var)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2030, in _create_weighted_sum
weight_var=weight_var)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2043, in _create_dense_column_weighted_sum
trainable=trainable)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2474, in _get_dense_tensor
return inputs.get(self)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2263, in get
transformed = column._transform_feature(self) # pylint: disable=protected-access
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2442, in _transform_feature
input_tensor = inputs.get(self.key)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2250, in get
feature_tensor = self._get_raw_feature_as_tensor(key)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py", line 2312, in _get_raw_feature_as_tensor
key, feature_tensor))]):
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/check_ops.py", line 198, in assert_positive
return assert_less(zero, x, data=data, summarize=summarize)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/check_ops.py", line 559, in assert_less
return control_flow_ops.Assert(condition, data, summarize=summarize)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py", line 118, in wrapped
return _add_should_use_warning(fn(*args, **kwargs))
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 149, in Assert
return gen_logging_ops._assert(condition, data, summarize, name="Assert")
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_logging_ops.py", line 51, in _assert
name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
self._traceback = tf_stack.extract_stack()
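The last frames (`_get_raw_feature_as_tensor` calling `assert_positive`) check that each raw feature tensor has rank greater than 0. Because `gen()` yields one row at a time and `input_csv` never uses `num_epochs` or `batch_size`, `age` reaches the model as a scalar, i.e. rank 0. The usual remedy is to call `Dataset.repeat(num_epochs)` and `Dataset.batch(batch_size)` on the dataset before returning it, which adds the leading batch dimension. Alternatively, the generator itself can emit column-wise batches. The plain-Python sketch below shows that grouping; `batched_rows` and the sample rows are hypothetical. If you wired it into `from_generator`, you would presumably also have to pass per-feature `output_shapes` such as `tf.TensorShape([None])`; that part is an assumption and is not exercised here.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Sequence, Tuple

Features = Dict[str, List]


def batched_rows(
    rows: Iterable[Sequence],
    feature_names: Sequence[str],
    label_fn: Callable[[Sequence], bool],
    batch_size: int,
) -> Iterator[Tuple[Features, List[bool]]]:
    """Hypothetical helper: group single rows into column-wise batches.

    Each yielded feature value is a list (rank 1) rather than a scalar
    (rank 0), and labels are collected into a parallel list.
    """
    features: Features = {name: [] for name in feature_names}
    labels: List[bool] = []
    for row in rows:
        # zip stops at len(feature_names), so trailing label columns are skipped
        for name, value in zip(feature_names, row):
            features[name].append(value)
        labels.append(label_fn(row))
        if len(labels) == batch_size:
            yield features, labels
            # start fresh containers so the yielded batch is never mutated
            features = {name: [] for name in feature_names}
            labels = []
    if labels:  # emit the final, possibly smaller, batch
        yield features, labels


# Usage with three made-up census-style rows: (age, workclass, income_bracket)
rows = [
    (39, "State-gov", ">50K"),
    (50, "Self-emp-not-inc", "<=50K"),
    (38, "Private", ">50K"),
]
batches = list(
    batched_rows(rows, ["age", "workclass"], lambda r: r[-1] == ">50K", batch_size=2)
)
for batch_features, batch_labels in batches:
    print(batch_features, batch_labels)
```

Batching inside `tf.data` is generally simpler; grouping in the generator only makes sense if you need control over how rows are combined.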
You cannot create a single Tensor object with different data types. Check out the official documentation.
Depending on your application, you can consider encoding everything as a string, as the documentation suggests, or using one-hot encoding and further preprocessing before converting to a tensor.
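To illustrate the one-hot option: a categorical string column such as `workclass` can be turned into fixed-length 0/1 integer vectors, so it shares a numeric dtype with the other features. This is a minimal plain-Python sketch; `build_vocab`, `one_hot`, and the sample values are hypothetical. Within TensorFlow itself, the comparable built-in route is `tf.feature_column.categorical_column_with_vocabulary_list` wrapped in `tf.feature_column.indicator_column`.

```python
from typing import Dict, List, Sequence


def build_vocab(values: Sequence[str]) -> Dict[str, int]:
    """Assign each distinct category an index, in first-seen order."""
    vocab: Dict[str, int] = {}
    for value in values:
        if value not in vocab:
            vocab[value] = len(vocab)
    return vocab


def one_hot(values: Sequence[str], vocab: Dict[str, int]) -> List[List[int]]:
    """Encode each value as a 0/1 vector; unseen values map to all zeros."""
    encoded = []
    for value in values:
        vector = [0] * len(vocab)
        if value in vocab:
            vector[vocab[value]] = 1
        encoded.append(vector)
    return encoded


workclass = ["Private", "State-gov", "Private", "Self-emp-not-inc"]
vocab = build_vocab(workclass)
encoded = one_hot(workclass, vocab)
print(vocab)    # {'Private': 0, 'State-gov': 1, 'Self-emp-not-inc': 2}
print(encoded)  # [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]
```

Mapping unseen categories to an all-zero vector is one design choice among several; a dedicated out-of-vocabulary slot is a common alternative.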