簡體   English   中英

Tensorflow stratified_sample錯誤

[英]Tensorflow stratified_sample error

我正在嘗試在tf.contrib.training.stratified_sample中使用tf.contrib.training.stratified_sample來平衡類。 我在下面做了一個快速的例子來測試它,以平衡的方式從兩個不平衡的類中抽取樣本並驗證它,但是我收到了一個錯誤。

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes

batch_size = 10
data = ['a']*9990+['b']*10
labels = [1]*9990+[0]*10
data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string)
label_tensor = ops.convert_to_tensor(labels)
target_probs = [0.5,0.5]
data_batch, label_batch = tf.contrib.training.stratified_sample(
    data_tensor, label_tensor, target_probs, batch_size,
    queue_capacity=2*batch_size)

with tf.Session() as sess:
    d,l = sess.run(data_batch,label_batch)
print('percentage "a" = %.3f' % (np.sum(l)/len(l)))

我得到的錯誤是:

Traceback (most recent call last):   
File "/home/jason/code/scrap.py", line 56, in <module>
    test_stratified_sample()   
File "/home/jason/code/scrap.py", line 47, in test_stratified_sample
    queue_capacity=2*batch_size)   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/contrib/training/python/training/sampling_ops.py", line 191, in stratified_sample
    with ops.name_scope(name, 'stratified_sample', tensors + [labels]):   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/ops/math_ops.py", line 829, in binary_op_wrapper
    y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 676, in convert_to_tensor
    as_ref=False)   File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 741, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 113, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 102, in constant
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 374, in make_tensor_proto
    _AssertCompatible(values, dtype)   
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 302, in _AssertCompatible
    (dtype.name, repr(mismatch), type(mismatch).__name__)) TypeError: Expected string, got list containing Tensors of type '_Message' instead.

錯誤並不能解釋我做錯了什么。 我還嘗試將原始數據和標簽放入(不轉換為張量),並嘗試使用tf.train.slice_input_producer創建數據的初始隊列並標記張量。

有沒有人得到stratified_sample才能工作? 我找不到任何例子。

我已經將代碼修改為適合我的東西。 變更摘要:

  • 使用enqueue_many=True將一批具有不同標簽的示例排入enqueue_many=True 否則它期望一個標量標簽Tensor(當由隊列運行者評估時可以是隨機的)。
  • 第一個論點預計將是一個Tensors列表。 它應該有一個更好的錯誤消息(我認為這是你遇到的)。 請在Github上發送拉取請求或打開問題以獲得更好的錯誤消息。
  • 啟動隊列運行器。 否則使用隊列的代碼將死鎖。 或者使用EstimatorMonitoredSession因此您無需擔心這一點。
  • (根據評論編輯) stratified_sample不會隨機播放數據,它只接受/拒絕! 因此,如果您的數據未被隨機化,請在采樣之前考慮將其放入slice_input_producerenqueue_many=False )或shuffle_batchenqueue_many=True ),如果您希望它以隨機順序出現。

修改后的代碼(根據Jason的評論改進):

import numpy
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes

with tf.Graph().as_default():
  batch_size = 100
  data = ['a']*9000+['b']*1000
  labels = [1]*9000+[0]*1000
  data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string)
  label_tensor = ops.convert_to_tensor(labels, dtype=dtypes.int32)
  shuffled_data, shuffled_labels = tf.train.slice_input_producer(
      [data_tensor, label_tensor], shuffle=True, capacity=3*batch_size)
  target_probs = numpy.array([0.5,0.5])
  data_batch, label_batch = tf.contrib.training.stratified_sample(
      [shuffled_data], shuffled_labels, target_probs, batch_size,
      queue_capacity=2*batch_size)

  with tf.Session() as session:
    tf.local_variables_initializer().run()
    tf.global_variables_initializer().run()
    coordinator = tf.train.Coordinator()
    tf.train.start_queue_runners(session, coord=coordinator)
    num_iter = 10
    sum_ones = 0.
    for _ in range(num_iter):
      d, l = session.run([data_batch, label_batch])
      count_ones = l.sum()
      sum_ones += float(count_ones)
      print('percentage "a" = %.3f' % (float(count_ones) / len(l)))
    print('Overall: {}'.format(sum_ones / (num_iter * batch_size)))
    coordinator.request_stop()
    coordinator.join()

輸出:

percentage "a" = 0.480
percentage "a" = 0.440
percentage "a" = 0.580
percentage "a" = 0.570
percentage "a" = 0.580
percentage "a" = 0.520
percentage "a" = 0.480
percentage "a" = 0.460
percentage "a" = 0.390
percentage "a" = 0.530
Overall: 0.503

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM