[英]How to randomly select from set of functions in TensorFlow using tf.function
My problem is this: during pre-processing I want to apply a function randomly selected from a set of functions to dataset examples using the tf.data.Dataset
and tf.function
API.我的问题是:在预处理期间,我想使用tf.data.Dataset
和tf.function
API 将从一组函数中随机选择的函数应用于数据集示例。
Specifically, my data are 3D volumes and I wish to apply a rotation from a set of 24 predefined rotation functions.具体来说,我的数据是 3D 体积,我希望从一组 24 个预定义的旋转函数中应用旋转。 I would like to write this code within a tf.function
so this limits the use of packages like numpy
and list indexing.我想在tf.function
编写此代码,因此这限制了numpy
和列表索引等包的使用。
For example, I would like to do something like this:例如,我想做这样的事情:
import tensorflow as tf
@tf.function
def func1(tensor):
# Apply some rotation here
...
@tf.function
def func2(tensor):
...
...
@tf.function
def func24(tensor):
...
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, ..., func24]
# Randomly sample from 0-23
a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
return list_of_funcs[a](tensor)
However I cannot index the list_of_funcs
as TypeError: list indices must be integers or slices, not Tensor
.但是我不能将list_of_funcs
索引为TypeError: list indices must be integers or slices, not Tensor
。 Additionally, I cannot collect these functions (AFAIK) into a tf.Tensor
and use tf.gather
.此外,我无法将这些函数(AFAIK)收集到tf.Tensor
并使用tf.gather
。
So my question: how can I reasonably and neatly sample from these functions in a tf.function
?所以我的问题是:如何在tf.function
从这些函数中合理而整齐地采样?
Maybe try using tf.py_function , which:也许尝试使用tf.py_function ,其中:
Wraps a python function into a TensorFlow op that executes it eagerly.将一个 python 函数包装到一个 TensorFlow op 中,它会急切地执行它。
For example:例如:
import tensorflow as tf
import random
@tf.function
def func1(tensor):
print('func1')
return tensor
@tf.function
def func2(tensor):
print('func2')
return tensor
@tf.function
def func3(tensor):
print('func3')
return tensor
@tf.function
def func4(tensor):
print('func4')
return tensor
@tf.function
def apply(tensor):
dispatcher = {
'func1': func1,
'func2': func2,
'func3': func3,
'func4': func4
}
keys = list(dispatcher)
def get_random_function_and_apply(t):
return dispatcher[random.choice(keys)](t)
y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)
return y
print(apply(tf.random.normal((5, 5, 5))))
'''
func4
tf.Tensor(
[[[ 0.6041213 -2.054427 1.1755397 -0.62914884 -0.00978021]
[ 0.06134182 -1.5529596 -0.3429052 -0.03199977 -1.1796658 ]
[-0.65084136 -1.5009187 -0.43266404 -0.18494445 1.2958355 ]
[-1.6614605 -0.7398612 1.5384725 -0.24926051 -0.5075399 ]
[ 0.7781286 -0.4102168 1.2152135 0.4508075 -1.7295381 ]]
[[-1.0509509 -1.271087 1.9061071 0.61855525 0.58581835]
[ 2.080663 0.43406835 0.32372198 -0.71427256 0.04448809]
[-0.6438594 -1.1245041 -0.4723388 -0.8302859 -2.0056007 ]
[ 1.1778332 0.2977344 0.7516829 1.1387901 -0.71768486]
[-0.44642782 -0.6523012 -0.48157197 -0.8197472 0.3635474 ]]
[[-0.43357274 1.166849 -0.04528571 0.44322303 0.74193203]
[ 1.2332342 0.07857647 1.3399298 0.62153 1.835202 ]
[ 0.48021084 0.36239776 0.16630112 0.59010863 1.8134127 ]
[-1.1444335 1.2445287 -1.2320557 0.08095992 -0.1379302 ]
[-1.101756 -1.8099649 0.18504284 0.15212883 0.33380997]]
[[-0.68228734 -0.82357454 -0.744171 -0.04959428 -1.3200126 ]
[ 0.813062 1.0669035 -0.7924809 -0.0548021 0.8043163 ]
[ 1.6480085 -0.17134379 0.25517386 0.02731211 1.2226027 ]
[-1.9785942 -0.22399756 -0.6814836 1.2065881 -1.7922156 ]
[-0.34833568 -1.0567352 1.5795225 0.14899854 0.5924402 ]]
[[-1.057639 -1.1659449 -0.22045298 0.39324322 -1.3500952 ]
[-0.32044935 0.9534627 0.40809664 -1.0296333 -0.8129102 ]
[-0.13515176 -0.32676768 -0.9333701 0.35130095 -1.5411847 ]
[ 2.090785 0.3497966 0.27694222 0.78199005 -0.08591356]
[ 0.9621986 -2.3930101 -1.1035724 0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)
'''
You can use a bunch of nested tf.cond
.您可以使用一堆嵌套的tf.cond
。 If a condition is met, it will call either the true_fn
or the false_fn
.如果满足条件,它将调用true_fn
或false_fn
。 Since you have more than two functions, you can nest them for as many functions as you like.由于您有两个以上的函数,因此您可以根据需要将它们嵌套为多个函数。 For instance, I'm making functions that multiply the input by either 2, 3, 4 or 5, depending on the value of a random variable.例如,我正在制作将输入乘以 2、3、4 或 5 的函数,具体取决于随机变量的值。
import tensorflow as tf
x = 10
@tf.function
def mult_2():
tf.print(f'i was 2, returning {x} multiplied by 2')
return tf.multiply(x, 2)
@tf.function
def mult_3():
tf.print(f'i was 3, returning {x} multiplied by 3')
return tf.multiply(x, 3)
@tf.function
def mult_4():
tf.print(f'i was 4, returning {x} multiplied by 4')
return tf.multiply(x, 4)
@tf.function
def mult_5():
tf.print(f'i was 5, returning {x} multiplied by 5')
return tf.multiply(x, 5)
i = tf.random.uniform((), 1, 5, dtype=tf.int32)
tf.cond(i == 2, mult_2,
lambda: tf.cond(i == 3, mult_3,
lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>
Note that mult_5
will execute if none of the conditions are met.请注意,如果mult_5
任何条件, mult_5
将执行。
You can usetf.switch_case
like您可以使用tf.switch_case
类的
def func1(tensor):
return tensor * 1
def func2(tensor):
return tensor * 2
def func24(tensor):
return tensor * 24
class Lambda:
def __init__(self, func, arg):
self._func = func
self._arg = arg
def __call__(self):
return self._func(self._arg)
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, func24]
branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
output = tf.switch_case(
branch_index=branch_index,
branch_fns=[Lambda(func, tensor) for func in list_of_funcs],
)
return output
Decorator @tf.function
is needed only for entire function you wish to optimize that is apply
in this case.装饰器@tf.function
仅apply
于您希望优化的整个函数, apply
于这种情况。 If you use apply
inside tf.data.Dataset.map
the decorator is not needed at all.如果您在tf.data.Dataset.map
使用apply
,则根本不需要装饰器。
See this discussion to understand why we have to define class Lambda
here.请参阅此讨论以了解为什么我们必须在此处定义Lambda
类。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.