
How do I bundle two tensors together in Tensorflow Batches?

I want to build a loss in TensorFlow that depends simultaneously on a reconstruction of X and on a function of the labels y. The X and y in each batch need to match, and I'm batching with tf.data.Dataset.batch() and training with tf.GradientTape rather than calling model.fit(..., batch_size=...) as you usually would. To solve the problem I thought of a couple of approaches:

  1. Bundle together X and Y as a tuple and turn that tuple into a tensorflow dataset and try to unpack after using .batch()
  2. Just tack the y on to the end of the X tensor and separate the two later after using .batch()

Is there a standard way to do this? I'm not sure whether the two approaches above are hacky. Number 2 seems like it would at least work, but I also wonder whether I've greatly overcomplicated things. At the moment my data loading setup, for the x's only, looks something like this:

train_dataset = (tf.data.Dataset.from_tensor_slices(train_dataset)
                 .shuffle(len(training_ind))
                 .batch(bsize))

And what I'm envisioning is something like:

train_x_dataset = tf.data.Dataset.from_tensor_slices(train_x)
train_y_dataset = tf.data.Dataset.from_tensor_slices(train_y)

train_dataset = (tf.data.Dataset((train_x_dataset,train_y_dataset))
                 .shuffle(len(training_ind))
                 .batch(bsize))

for train_x, train_y in train_dataset:
  loss(train_step(model, train_x, train_y, optimizer))

This turned out to be much easier to solve than I expected. I already had the basic idea of passing the two arrays together as a tuple; the only difference is that you have to pass both when you call .from_tensor_slices, not afterwards. Each element of the dataset is then an (x, y) pair, so iterating over the batched dataset unpacks as expected.

train_dataset = tf.data.Dataset.from_tensor_slices((train_x,train_y))

train_dataset = (train_dataset
                     .shuffle(n_obs)
                     .batch(bsize))

for train_x, train_y in train_dataset:
  loss(train_step(model, train_x, train_y, optimizer))
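
For anyone who wants to check this end to end, here is a minimal self-contained sketch. The toy data, the two-headed Keras model, and the body of train_step are all assumptions made up for illustration (the original train_step and model are not shown above); only the from_tensor_slices((train_x, train_y)) pattern comes from the answer. The labels are derived from each row of X so that the loop can verify the pairs stay aligned after shuffling and batching.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 100 observations, 8 features, one scalar label each.
n_obs, n_features, bsize = 100, 8, 32
rng = np.random.default_rng(0)
train_x = rng.normal(size=(n_obs, n_features)).astype("float32")
# Labels are the row sums, so the x/y pairing can be checked after shuffling.
train_y = train_x.sum(axis=1, keepdims=True)

# Passing a tuple makes every dataset element an (x, y) pair.
train_dataset = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
                 .shuffle(n_obs)
                 .batch(bsize))

# Hypothetical model: a shared hidden layer with a reconstruction head
# and a label head.
inputs = tf.keras.Input(shape=(n_features,))
hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
x_hat_out = tf.keras.layers.Dense(n_features)(hidden)
y_hat_out = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs, [x_hat_out, y_hat_out])
optimizer = tf.keras.optimizers.Adam(1e-3)


def train_step(model, batch_x, batch_y, optimizer):
    """One gradient step on reconstruction error plus label error."""
    with tf.GradientTape() as tape:
        x_hat, y_hat = model(batch_x, training=True)
        loss = (tf.reduce_mean(tf.square(batch_x - x_hat))
                + tf.reduce_mean(tf.square(batch_y - y_hat)))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


batch_shapes = []
losses = []
pairs_aligned = True
for batch_x, batch_y in train_dataset:
    batch_shapes.append((tuple(batch_x.shape), tuple(batch_y.shape)))
    # Each label should still equal the sum of its own row of X.
    row_sums = batch_x.numpy().sum(axis=1, keepdims=True)
    pairs_aligned = pairs_aligned and bool(
        np.allclose(row_sums, batch_y.numpy(), atol=1e-4))
    losses.append(float(train_step(model, batch_x, batch_y, optimizer)))

print(batch_shapes)
print(pairs_aligned)
```

Because shuffle and batch operate on whole (x, y) elements, there is no need to concatenate y onto X and split it off again later. If the two datasets have already been built separately, tf.data.Dataset.zip((train_x_dataset, train_y_dataset)) should give the same pairing.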

