繁体   English   中英

有没有办法知道是否在张量流数据集上使用了`.repeat` /`.batch` /`.shuffle`?

[英]Is there a way of knowing if `.repeat`/`.batch`/`.shuffle` have been used on a tensorflow dataset?

获得一个已经构建的tensorflow数据集对象( tf.data.Dataset )命名data

有没有办法知道是否通过检查数据来调用此对象上的函数repeat / batch / shuffle (并可能得到其他信息,如批处理和重复的参数)

(我假设急切执行)

编辑1:似乎行str方法携带一些信息。 调查那个。

编辑2:属性output_shapes提供有关批量大小和形状的信息。

我能想到的唯一解决方案是进入tensorflow代码。 gen_dataset_ops.py是在从源构建期间生成的,因此只能在本地找到它。

另一个文件是dataset_ops.py ,它可以在下面的链接中找到。 您只需在相关函数返回之前插入print语句。 例如,来自dataset_ops.py shuffle函数:

def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
"""Randomly shuffles the elements of this dataset.
...
print('Dataset shuffled') #inserted print here
return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)

数据集对象被包装到DatasetV1Adapter ,因此您无法了解它的进展情况。 渴望模式的唯一区别是它支持显式迭代,但是做类似smth的效率非常低

array = np.random.rand(10)
dataset = tf.data.Dataset.from_tensor_slices(array)
if len([i for i in dataset]) != array.shape[0]:
    print('repeated')

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM