簡體   English   中英

如何使用 tf.function 從 TensorFlow 中的一組函數中隨機選擇

[英]How to randomly select from set of functions in TensorFlow using tf.function

我的問題是:在預處理期間,我想使用tf.data.Datasettf.function API 將從一組函數中隨機選擇的函數應用於數據集示例。

具體來說,我的數據是 3D 體積,我希望從一組 24 個預定義的旋轉函數中應用旋轉。 我想在tf.function編寫此代碼,因此這限制了numpy和列表索引等包的使用。

例如,我想做這樣的事情:

import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...


@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample from 0-23
    a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
    
    return list_of_funcs[a](tensor)

但是我不能將list_of_funcs索引為TypeError: list indices must be integers or slices, not Tensor 此外,我無法將這些函數(AFAIK)收集到tf.Tensor並使用tf.gather

所以我的問題是:如何在tf.function從這些函數中合理而整齊地采樣?

也許嘗試使用tf.py_function ,其中:

將一個 python 函數包裝到一個 TensorFlow op 中,它會急切地執行它。

例如:

import tensorflow as tf
import random

@tf.function
def func1(tensor):
    print('func1')
    return tensor

@tf.function
def func2(tensor):
    print('func2')
    return tensor

@tf.function
def func3(tensor):
    print('func3')
    return tensor

@tf.function
def func4(tensor):
    print('func4')
    return tensor

@tf.function
def apply(tensor):
    dispatcher = {
        'func1': func1,
        'func2': func2,
        'func3': func3,
        'func4': func4
    }
    keys = list(dispatcher)
    
    def get_random_function_and_apply(t):
      return dispatcher[random.choice(keys)](t)

    y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)
                       
    return y
    
print(apply(tf.random.normal((5, 5, 5))))

'''
func4
tf.Tensor(
[[[ 0.6041213  -2.054427    1.1755397  -0.62914884 -0.00978021]
  [ 0.06134182 -1.5529596  -0.3429052  -0.03199977 -1.1796658 ]
  [-0.65084136 -1.5009187  -0.43266404 -0.18494445  1.2958355 ]
  [-1.6614605  -0.7398612   1.5384725  -0.24926051 -0.5075399 ]
  [ 0.7781286  -0.4102168   1.2152135   0.4508075  -1.7295381 ]]

 [[-1.0509509  -1.271087    1.9061071   0.61855525  0.58581835]
  [ 2.080663    0.43406835  0.32372198 -0.71427256  0.04448809]
  [-0.6438594  -1.1245041  -0.4723388  -0.8302859  -2.0056007 ]
  [ 1.1778332   0.2977344   0.7516829   1.1387901  -0.71768486]
  [-0.44642782 -0.6523012  -0.48157197 -0.8197472   0.3635474 ]]

 [[-0.43357274  1.166849   -0.04528571  0.44322303  0.74193203]
  [ 1.2332342   0.07857647  1.3399298   0.62153     1.835202  ]
  [ 0.48021084  0.36239776  0.16630112  0.59010863  1.8134127 ]
  [-1.1444335   1.2445287  -1.2320557   0.08095992 -0.1379302 ]
  [-1.101756   -1.8099649   0.18504284  0.15212883  0.33380997]]

 [[-0.68228734 -0.82357454 -0.744171   -0.04959428 -1.3200126 ]
  [ 0.813062    1.0669035  -0.7924809  -0.0548021   0.8043163 ]
  [ 1.6480085  -0.17134379  0.25517386  0.02731211  1.2226027 ]
  [-1.9785942  -0.22399756 -0.6814836   1.2065881  -1.7922156 ]
  [-0.34833568 -1.0567352   1.5795225   0.14899854  0.5924402 ]]

 [[-1.057639   -1.1659449  -0.22045298  0.39324322 -1.3500952 ]
  [-0.32044935  0.9534627   0.40809664 -1.0296333  -0.8129102 ]
  [-0.13515176 -0.32676768 -0.9333701   0.35130095 -1.5411847 ]
  [ 2.090785    0.3497966   0.27694222  0.78199005 -0.08591356]
  [ 0.9621986  -2.3930101  -1.1035724   0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)

'''

您可以使用一堆嵌套的tf.cond 如果滿足條件,它將調用true_fnfalse_fn 由於您有兩個以上的函數,因此您可以根據需要將它們嵌套為多個函數。 例如,我正在制作將輸入乘以 2、3、4 或 5 的函數,具體取決於隨機變量的值。

import tensorflow as tf

x = 10


@tf.function
def mult_2():
    tf.print(f'i was 2, returning {x} multiplied by 2')
    return tf.multiply(x, 2)


@tf.function
def mult_3():
    tf.print(f'i was 3, returning {x} multiplied by 3')
    return tf.multiply(x, 3)


@tf.function
def mult_4():
    tf.print(f'i was 4, returning {x} multiplied by 4')
    return tf.multiply(x, 4)


@tf.function
def mult_5():
    tf.print(f'i was 5, returning {x} multiplied by 5')
    return tf.multiply(x, 5)


i = tf.random.uniform((), 1, 5, dtype=tf.int32)

tf.cond(i == 2, mult_2,
        lambda: tf.cond(i == 3, mult_3,
                        lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>

請注意,如果mult_5任何條件, mult_5將執行。

您可以使用tf.switch_case類的

def func1(tensor):
    return tensor * 1

def func2(tensor):
    return tensor * 2

def func24(tensor):
    return tensor * 24

class Lambda:
    def __init__(self, func, arg):
        self._func = func
        self._arg = arg
        
    def __call__(self):
        return self._func(self._arg)

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, func24]

    branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
    output = tf.switch_case(
        branch_index=branch_index, 
        branch_fns=[Lambda(func, tensor) for func in list_of_funcs], 
    )
    
    return output

裝飾器@tf.functionapply於您希望優化的整個函數, apply於這種情況。 如果您在tf.data.Dataset.map使用apply ,則根本不需要裝飾器。

請參閱此討論以了解為什么我們必須在此處定義Lambda類。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM