Randomly sample from multiple tf.data.Datasets in Tensorflow

Suppose I have N tf.data.Datasets and a list of N probabilities (summing to 1). I would like to create a dataset whose examples are sampled from the N datasets with the given probabilities.

I would like this to work for arbitrary probabilities, so a simple zip/concat/flatmap with a fixed number of examples from each dataset is probably not what I am looking for.

Is it possible to do this in TF? Thanks!

As of 1.12, tf.data.experimental.sample_from_datasets provides this functionality: https://www.tensorflow.org/api_docs/python/tf/data/experimental/sample_from_datasets

EDIT: It looks like in earlier versions this can be accessed as tf.contrib.data.sample_from_datasets.
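For TensorFlow 2.x with eager execution, a minimal sketch of this approach might look like the following. The three toy datasets, the weights, and the seed are illustrative assumptions and not part of the original answer; newer releases also expose the same functionality as tf.data.Dataset.sample_from_datasets.

```python
import collections

import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Toy source datasets (illustrative): each one repeats a single letter forever.
datasets = [
    tf.data.Dataset.from_tensors("a").repeat(),
    tf.data.Dataset.from_tensors("b").repeat(),
    tf.data.Dataset.from_tensors("c").repeat(),
]
weights = [0.2, 0.3, 0.5]  # illustrative probabilities, one per dataset

# Each element of `combined` is drawn from datasets[i] with probability weights[i].
combined = tf.data.experimental.sample_from_datasets(
    datasets, weights=weights, seed=0)

# Count how often each source was picked over 2000 draws.
counts = collections.Counter(
    element.numpy().decode() for element in combined.take(2000))
print(counts)
```

The observed frequencies should be close to the requested weights, though the exact counts depend on the seed and the TensorFlow version.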

If p is a Tensor of probabilities (or unnormalized relative probabilities) where p[i] is the probability that dataset i is chosen, you can use tf.multinomial in conjunction with tf.contrib.data.choose_from_datasets:

import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)

# create some datasets and their unnormalized probability of being chosen
datasets = [
    tf.data.Dataset.from_tensors(['a']).repeat(),
    tf.data.Dataset.from_tensors(['b']).repeat(),
    tf.data.Dataset.from_tensors(['c']).repeat(),
    tf.data.Dataset.from_tensors(['d']).repeat()]
p = [1., 2., 3., 4.]  # unnormalized

# random choice function
def get_random_choice(p):
  choice = tf.multinomial(tf.log([p]), 1)
  return tf.cast(tf.squeeze(choice), tf.int64)

# assemble the "choosing" dataset
choice_dataset = tf.data.Dataset.from_tensors([0])  # create a dummy dataset
choice_dataset = choice_dataset.map(lambda x: get_random_choice(p))  # populate it with random choices
choice_dataset = choice_dataset.repeat()  # repeat

# obtain your combined dataset, assembled randomly from source datasets
# with the desired selection frequencies. 
combined_dataset = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

Note that the dataset needs to be initialized (you can't use a simple make_one_shot_iterator):

choice_iterator = combined_dataset.make_initializable_iterator()
choice = choice_iterator.get_next()
with tf.Session() as sess:
  sess.run(choice_iterator.initializer)
  # elements come back as byte strings, so decode them before joining
  print(''.join(sess.run(choice)[0].decode() for _ in range(20)))

>> ddbcccdcccbbddadcadb
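The reason unnormalized weights work here is that tf.multinomial interprets its input as logits (unnormalized log-probabilities), so passing log(p) yields selection probabilities of p[i] / sum(p). The following plain-Python sketch checks that arithmetic without TensorFlow; the variable names are illustrative only.

```python
import math

p = [1.0, 2.0, 3.0, 4.0]  # unnormalized weights, as in the answer above

# Logits as passed to tf.multinomial in the answer: log of each weight.
logits = [math.log(w) for w in p]

# Softmax over the logits recovers the normalized probabilities.
exp_logits = [math.exp(z) for z in logits]
total = sum(exp_logits)
probs = [e / total for e in exp_logits]

print(probs)  # approximately 0.1, 0.2, 0.3, 0.4
```

With these weights, dataset 'd' should be picked about four times as often as dataset 'a', which is consistent with the sample output shown above.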

I think you can use tf.contrib.data.rejection_resample to achieve the target distribution.
