[英]How to randomly select from set of functions in TensorFlow using tf.function
我的問題是:在預處理期間,我想使用tf.data.Dataset
和tf.function
API 將從一組函數中隨機選擇的函數應用於數據集示例。
具體來說,我的數據是 3D 體積,我希望從一組 24 個預定義的旋轉函數中應用旋轉。 我想在tf.function
編寫此代碼,因此這限制了numpy
和列表索引等包的使用。
例如,我想做這樣的事情:
import tensorflow as tf
@tf.function
def func1(tensor):
# Apply some rotation here
...
@tf.function
def func2(tensor):
...
...
@tf.function
def func24(tensor):
...
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, ..., func24]
# Randomly sample from 0-23
a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
return list_of_funcs[a](tensor)
但是我不能將list_of_funcs
索引為TypeError: list indices must be integers or slices, not Tensor
。 此外,我無法將這些函數(AFAIK)收集到tf.Tensor
並使用tf.gather
。
所以我的問題是:如何在tf.function
從這些函數中合理而整齊地采樣?
也許嘗試使用tf.py_function ,其中:
將一個 python 函數包裝到一個 TensorFlow op 中,它會急切地執行它。
例如:
import tensorflow as tf
import random
@tf.function
def func1(tensor):
print('func1')
return tensor
@tf.function
def func2(tensor):
print('func2')
return tensor
@tf.function
def func3(tensor):
print('func3')
return tensor
@tf.function
def func4(tensor):
print('func4')
return tensor
@tf.function
def apply(tensor):
dispatcher = {
'func1': func1,
'func2': func2,
'func3': func3,
'func4': func4
}
keys = list(dispatcher)
def get_random_function_and_apply(t):
return dispatcher[random.choice(keys)](t)
y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)
return y
print(apply(tf.random.normal((5, 5, 5))))
'''
func4
tf.Tensor(
[[[ 0.6041213 -2.054427 1.1755397 -0.62914884 -0.00978021]
[ 0.06134182 -1.5529596 -0.3429052 -0.03199977 -1.1796658 ]
[-0.65084136 -1.5009187 -0.43266404 -0.18494445 1.2958355 ]
[-1.6614605 -0.7398612 1.5384725 -0.24926051 -0.5075399 ]
[ 0.7781286 -0.4102168 1.2152135 0.4508075 -1.7295381 ]]
[[-1.0509509 -1.271087 1.9061071 0.61855525 0.58581835]
[ 2.080663 0.43406835 0.32372198 -0.71427256 0.04448809]
[-0.6438594 -1.1245041 -0.4723388 -0.8302859 -2.0056007 ]
[ 1.1778332 0.2977344 0.7516829 1.1387901 -0.71768486]
[-0.44642782 -0.6523012 -0.48157197 -0.8197472 0.3635474 ]]
[[-0.43357274 1.166849 -0.04528571 0.44322303 0.74193203]
[ 1.2332342 0.07857647 1.3399298 0.62153 1.835202 ]
[ 0.48021084 0.36239776 0.16630112 0.59010863 1.8134127 ]
[-1.1444335 1.2445287 -1.2320557 0.08095992 -0.1379302 ]
[-1.101756 -1.8099649 0.18504284 0.15212883 0.33380997]]
[[-0.68228734 -0.82357454 -0.744171 -0.04959428 -1.3200126 ]
[ 0.813062 1.0669035 -0.7924809 -0.0548021 0.8043163 ]
[ 1.6480085 -0.17134379 0.25517386 0.02731211 1.2226027 ]
[-1.9785942 -0.22399756 -0.6814836 1.2065881 -1.7922156 ]
[-0.34833568 -1.0567352 1.5795225 0.14899854 0.5924402 ]]
[[-1.057639 -1.1659449 -0.22045298 0.39324322 -1.3500952 ]
[-0.32044935 0.9534627 0.40809664 -1.0296333 -0.8129102 ]
[-0.13515176 -0.32676768 -0.9333701 0.35130095 -1.5411847 ]
[ 2.090785 0.3497966 0.27694222 0.78199005 -0.08591356]
[ 0.9621986 -2.3930101 -1.1035724 0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)
'''
您可以使用一堆嵌套的tf.cond
。 如果滿足條件,它將調用true_fn
或false_fn
。 由於您有兩個以上的函數,因此您可以根據需要將它們嵌套為多個函數。 例如,我正在制作將輸入乘以 2、3、4 或 5 的函數,具體取決於隨機變量的值。
import tensorflow as tf
x = 10
@tf.function
def mult_2():
tf.print(f'i was 2, returning {x} multiplied by 2')
return tf.multiply(x, 2)
@tf.function
def mult_3():
tf.print(f'i was 3, returning {x} multiplied by 3')
return tf.multiply(x, 3)
@tf.function
def mult_4():
tf.print(f'i was 4, returning {x} multiplied by 4')
return tf.multiply(x, 4)
@tf.function
def mult_5():
tf.print(f'i was 5, returning {x} multiplied by 5')
return tf.multiply(x, 5)
i = tf.random.uniform((), 1, 5, dtype=tf.int32)
tf.cond(i == 2, mult_2,
lambda: tf.cond(i == 3, mult_3,
lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>
請注意,如果mult_5
任何條件, mult_5
將執行。
您可以使用tf.switch_case
類的
def func1(tensor):
return tensor * 1
def func2(tensor):
return tensor * 2
def func24(tensor):
return tensor * 24
class Lambda:
def __init__(self, func, arg):
self._func = func
self._arg = arg
def __call__(self):
return self._func(self._arg)
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, func24]
branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
output = tf.switch_case(
branch_index=branch_index,
branch_fns=[Lambda(func, tensor) for func in list_of_funcs],
)
return output
裝飾器@tf.function
僅apply
於您希望優化的整個函數, apply
於這種情況。 如果您在tf.data.Dataset.map
使用apply
,則根本不需要裝飾器。
請參閱此討論以了解為什么我們必須在此處定義Lambda
類。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.