繁体   English   中英

tf.data.Dataset 的一个热编码标签

[英]one hot encode labels of tf.data.Dataset

我正在尝试将 tf.data.Dataset 的标签转换为一个热编码标签。 我正在使用这个数据集。 我在列中添加了标题(情感、文本),其他一切都是原创的。

这是我用来将标签(正面、负面、中性)编码为一个热点 (3,) 的代码:

def _map_func(text, labels):
   labels_enc = []
   for label in labels:
      if label=='negative':
         label = -1
      elif label=='neutral':
         label = 0
      else: 
         label = 1

      label = tf.one_hot(
         label, 3, name='label', axis=-1)

      labels_enc.append(label)

   return text, labels_enc

raw_train_ds = tf.data.experimental.make_csv_dataset(
   './data/sentiment_data/train.csv', BATCH_SIZE, column_names=['sentiment', 'text'],
   label_name='sentiment', header=True
)

train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)

train_ds = train_ds.map(_map_func)

我收到错误消息: ValueError: Value [<tf.Tensor 'while/label:0' shape=(3,) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (1, 3).

_map_func(text, label) label 的第二个参数的形状为 (64,) type=string。

如果我正确理解了 tensorflows tf.data.Dataset.map function,它会创建一个新数据集,其中包含转换 ZC1C425268E68385D1AB5074C17A94F14 应用的转换。 但是由于错误指出标签的列不能从具有一个字符串的列转换为具有包含 3 个浮点数的列表的列。 有没有办法强制新列的类型接受编码标签?

谢谢您的帮助:)

映射 function 适用于每个元素,因此您无需创建列表并遍历批处理项。 仅对一个样本进行尝试:

def _map_func(text, label):
    if label=='negative':
        label = -1
    elif label=='neutral':
        label = 0
    else: 
        label = 1

    label = tf.one_hot(label, 3, name='label', axis=-1)

   return text, label

我通过使用 TensorFlow TensorArray 解决了这个问题,如下所示:

def _map_func(text, labels):
    i=0
    labels_enc = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
        clear_after_read=False)
    for label in labels:
        if label=='negative':
            label = tf.constant(-1)
        elif label=='neutral':
            label = tf.constant(0)
        else: 
            label = tf.constant(1)

        label = tf.one_hot(
            label, 3, name='label', axis=-1)

        labels_enc.write(i, label)
            i = i+1

    return text, labels_enc.concat()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM