Tensorflow stratified_sample error
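I am trying to use tf.contrib.training.stratified_sample in Tensorflow to balance classes. I wrote the quick example below to test it: it should draw samples from two unbalanced classes in a balanced way and verify the result, but I get an error.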
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
batch_size = 10
data = ['a']*9990+['b']*10
labels = [1]*9990+[0]*10
data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string)
label_tensor = ops.convert_to_tensor(labels)
target_probs = [0.5,0.5]
data_batch, label_batch = tf.contrib.training.stratified_sample(
    data_tensor, label_tensor, target_probs, batch_size,
    queue_capacity=2*batch_size)
with tf.Session() as sess:
    d,l = sess.run(data_batch,label_batch)
    print('percentage "a" = %.3f' % (np.sum(l)/len(l)))
The error I get is:
Traceback (most recent call last):
File "/home/jason/code/scrap.py", line 56, in <module>
test_stratified_sample()
File "/home/jason/code/scrap.py", line 47, in test_stratified_sample
queue_capacity=2*batch_size)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/contrib/training/python/training/sampling_ops.py", line 191, in stratified_sample
with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/ops/math_ops.py", line 829, in binary_op_wrapper
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 676, in convert_to_tensor
as_ref=False)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 741, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 113, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 102, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 374, in make_tensor_proto
_AssertCompatible(values, dtype)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 302, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).__name__))
TypeError: Expected string, got list containing Tensors of type '_Message' instead.
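The error does not explain what I am doing wrong. I also tried passing in the raw data and labels (without converting them to tensors), and tried using tf.train.slice_input_producer to create an initial queue of the data and label tensors.
Has anyone gotten stratified_sample to work? I could not find any examples.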
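I have modified the code into something that works for me. Summary of changes:
- Use enqueue_many=True to enqueue a batch of examples with different labels. Otherwise it expects a single scalar label Tensor (which can be stochastic when evaluated by queue runners).
- Pass the first argument as a list of Tensors, as the modified code does with [shuffled_data].
- Start the queue runners, or use Estimator or MonitoredSession so you don't need to worry about this.
- stratified_sample does not shuffle the data, it only accepts/rejects! So if your data is not randomized, consider putting it through slice_input_producer (enqueue_many=False) or shuffle_batch (enqueue_many=True) before sampling if you want it to come out in random order.
Modified code (improved based on Jason's comments):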
import numpy
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
with tf.Graph().as_default():
    batch_size = 100
    data = ['a']*9000+['b']*1000
    labels = [1]*9000+[0]*1000
    data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string)
    label_tensor = ops.convert_to_tensor(labels, dtype=dtypes.int32)
    # Shuffle first: stratified_sample only accepts/rejects, it does not reorder.
    shuffled_data, shuffled_labels = tf.train.slice_input_producer(
        [data_tensor, label_tensor], shuffle=True, capacity=3*batch_size)
    target_probs = numpy.array([0.5,0.5])
    # The first argument is a list of Tensors; the label is a scalar Tensor.
    data_batch, label_batch = tf.contrib.training.stratified_sample(
        [shuffled_data], shuffled_labels, target_probs, batch_size,
        queue_capacity=2*batch_size)
    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()
        # Queue-based input pipelines block until queue runners are started.
        coordinator = tf.train.Coordinator()
        tf.train.start_queue_runners(session, coord=coordinator)
        num_iter = 10
        sum_ones = 0.
        for _ in range(num_iter):
            d, l = session.run([data_batch, label_batch])
            count_ones = l.sum()
            sum_ones += float(count_ones)
            print('percentage "a" = %.3f' % (float(count_ones) / len(l)))
        print('Overall: {}'.format(sum_ones / (num_iter * batch_size)))
        coordinator.request_stop()
        coordinator.join()
Output:
percentage "a" = 0.480
percentage "a" = 0.440
percentage "a" = 0.580
percentage "a" = 0.570
percentage "a" = 0.580
percentage "a" = 0.520
percentage "a" = 0.480
percentage "a" = 0.460
percentage "a" = 0.390
percentage "a" = 0.530
Overall: 0.503
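To make the accept/reject point above concrete, here is a minimal NumPy sketch of rejection-based class balancing. It illustrates the general idea only and is not the code inside tf.contrib.training.stratified_sample. The helper names acceptance_probs and rejection_sample are made up for this example, and the sketch assumes every class appears at least once in the data.

```python
import numpy as np

def acceptance_probs(class_priors, target_probs):
    """Per-class acceptance probability for rejection sampling.

    Hypothetical helper: scales target/prior so that the most
    under-represented class is always accepted.
    Assumes every entry of class_priors is non-zero.
    """
    ratio = np.asarray(target_probs, dtype=float) / np.asarray(class_priors, dtype=float)
    return ratio / ratio.max()

def rejection_sample(labels, target_probs, rng):
    """Return the indices of the examples that are accepted.

    Each example is kept or dropped independently, so the relative
    order of the kept examples is the same as in the input.
    """
    labels = np.asarray(labels)
    priors = np.bincount(labels, minlength=len(target_probs)) / len(labels)
    accept = acceptance_probs(priors, target_probs)
    keep = rng.random(len(labels)) < accept[labels]
    return np.flatnonzero(keep)

# Same class imbalance as the example above: 90% label 1, 10% label 0.
rng = np.random.default_rng(0)
labels = np.array([1] * 9000 + [0] * 1000)
kept = rejection_sample(labels, [0.5, 0.5], rng)
frac_ones = labels[kept].mean()
print('fraction of label 1 after sampling = %.3f' % frac_ones)
```

Because the kept indices stay in their original order, unshuffled input like the list above would yield all the accepted label 1 examples before any label 0 example, which is why shuffling before sampling matters.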