簡體   English   中英

如何在 TensorFlow 的張量切片數據集中填充序列?

[英]how to pad sequences in a tensor slice dataset in TensorFlow?

我有一個由兩個參差不齊的張量組成的張量切片數據集。

張量_a 就像: <tf.RaggedTensor [[3, 3, 5], [3, 3, 14, 4, 17, 20], [3, 14, 22, 17]]>

tensor_b 就像: <tf.RaggedTensor [[-1, 1, -1], [-1, -1, 1, -1, -1, -1], [-1, 1, -1, 2]]>

(tensor_a 和 tensor_b 的索引相同,長度相同。)

我制作了數據集

dataset = tf.data.Dataset.from_tensor_slices((tensor_a, tensor_b))
dataset

<TensorSliceDataset element_spec=(RaggedTensorSpec(TensorShape([None]), tf.int64, 0, tf.int64), RaggedTensorSpec(TensorShape([None]), tf.int32, 0, tf.int64))>

如何填充數據集中的序列? 我試過tf.padtf.keras.preprocessing.sequence.pad_sequences但沒有找到正確的方法。

你可以嘗試這樣的事情:

import tensorflow as tf

tensor_a = tf.ragged.constant([[3, 3, 5], [3, 3, 14, 4, 17, 20], [3, 14, 22, 17]])
tensor_b = tf.ragged.constant([[-1, 1, -1], [-1, -1, 1, -1, -1, -1], [-1, 1, -1, 2]])
dataset = tf.data.Dataset.from_tensor_slices((tensor_a, tensor_b))

max_length = max(list(dataset.map(lambda x, y: tf.shape(x)[0])))

def pad(x, y):
  x = tf.concat([x, tf.zeros((int(max_length-tf.shape(x)[0]),), dtype=tf.int32)], axis=0)
  y = tf.concat([y, tf.zeros((int(max_length-tf.shape(y)[0]),), dtype=tf.int32)], axis=0)
  return x, y

dataset = dataset.map(pad)
for x, y in dataset:
  print(x, y)
tf.Tensor([3 3 5 0 0 0], shape=(6,), dtype=int32) tf.Tensor([-1  1 -1  0  0  0], shape=(6,), dtype=int32)
tf.Tensor([ 3  3 14  4 17 20], shape=(6,), dtype=int32) tf.Tensor([-1 -1  1 -1 -1 -1], shape=(6,), dtype=int32)
tf.Tensor([ 3 14 22 17  0  0], shape=(6,), dtype=int32) tf.Tensor([-1  1 -1  2  0  0], shape=(6,), dtype=int32)

對於預填充,只需調整pad function:

def pad(x, y):
  x = tf.concat([tf.zeros((int(max_length-tf.shape(x)[0]),), dtype=tf.int32), x], axis=0)
  y = tf.concat([tf.zeros((int(max_length-tf.shape(y)[0]),), dtype=tf.int32), y], axis=0)
  return x, y
tf.Tensor([0 0 0 3 3 5], shape=(6,), dtype=int32) tf.Tensor([ 0  0  0 -1  1 -1], shape=(6,), dtype=int32)
tf.Tensor([ 3  3 14  4 17 20], shape=(6,), dtype=int32) tf.Tensor([-1 -1  1 -1 -1 -1], shape=(6,), dtype=int32)
tf.Tensor([ 0  0  3 14 22 17], shape=(6,), dtype=int32) tf.Tensor([ 0  0 -1  1 -1  2], shape=(6,), dtype=int32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM