简体   繁体   English

如何更改特定索引中的 tf.Dataset 对象的值

[英]How to change the values of a tf.Dataset object in a specific index

The structure of my tf.data.Dataset object is as follow.我的 tf.data.Dataset 对象的结构如下。 ((3, 400, 1), (3, 400, 1)) ((3, 400, 1), (3, 400, 1))

I would like to divide the elements in the 3rd row, of each element by 10. My code is as follows.我想将每个元素的第 3 行中的元素除以 10。我的代码如下。 But it complains as NumPy arrays are immutable (I'd like to use map )但它抱怨 NumPy 数组是不可变的(我想使用map

def alternate_row (dataset):
  xx, yy = [], [] 
  for x, y in dataset.as_numpy_iterator():
    x[2] /= 10
    y[2] /= 10
    xx.append(x)
    yy.append(y)

  return xx, yy

Try using tf.data.Dataset.map and tf.concat :尝试使用tf.data.Dataset.maptf.concat

import tensorflow as tf

samples = 5
x1 = tf.random.normal((samples, 3, 400, 1))
x2 = tf.random.normal((samples, 3, 400, 1))

dataset = tf.data.Dataset.from_tensor_slices((x1, x2))

def divide(x1, x2):
  x1 = tf.concat([x1[:2], x1[2:] / 10], axis=0)
  x2 = tf.concat([x2[:2], x2[2:] / 10], axis=0)
  return x1, x2

dataset = dataset.map(divide)

Note that I assume you want to change the values in the second dimension of the tensors, but you can change the notation for the slice to suit your needs.请注意,我假设您想要更改张量的第二维中的值,但您可以更改切片的符号以满足您的需要。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM