
How to randomly select from set of functions in TensorFlow using tf.function

My problem is this: during pre-processing I want to apply a function, randomly selected from a set of functions, to dataset examples using the tf.data.Dataset and tf.function APIs.

Specifically, my data are 3D volumes and I wish to apply a rotation chosen from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like numpy and plain list indexing.

For example, I would like to do something like this:

import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...


@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample an index from 0-23 (maxval is exclusive)
    a = tf.random.uniform([1], minval=0, maxval=24, dtype=tf.int32)
    
    return list_of_funcs[a](tensor)

However, I cannot index list_of_funcs with a tensor; it raises TypeError: list indices must be integers or slices, not Tensor. Additionally, I cannot (AFAIK) collect these functions into a tf.Tensor and use tf.gather.

So my question: how can I reasonably and neatly sample from these functions in a tf.function?

Maybe try using tf.py_function, which:

Wraps a python function into a TensorFlow op that executes it eagerly.

For example:

import tensorflow as tf
import random

@tf.function
def func1(tensor):
    print('func1')
    return tensor

@tf.function
def func2(tensor):
    print('func2')
    return tensor

@tf.function
def func3(tensor):
    print('func3')
    return tensor

@tf.function
def func4(tensor):
    print('func4')
    return tensor

@tf.function
def apply(tensor):
    dispatcher = {
        'func1': func1,
        'func2': func2,
        'func3': func3,
        'func4': func4
    }
    keys = list(dispatcher)
    
    def get_random_function_and_apply(t):
        # Executed eagerly by tf.py_function, so plain Python such as
        # random.choice and dict lookups works here.
        return dispatcher[random.choice(keys)](t)

    y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)

    return y
    
print(apply(tf.random.normal((5, 5, 5))))

'''
func4
tf.Tensor(
[[[ 0.6041213  -2.054427    1.1755397  -0.62914884 -0.00978021]
  [ 0.06134182 -1.5529596  -0.3429052  -0.03199977 -1.1796658 ]
  [-0.65084136 -1.5009187  -0.43266404 -0.18494445  1.2958355 ]
  [-1.6614605  -0.7398612   1.5384725  -0.24926051 -0.5075399 ]
  [ 0.7781286  -0.4102168   1.2152135   0.4508075  -1.7295381 ]]

 [[-1.0509509  -1.271087    1.9061071   0.61855525  0.58581835]
  [ 2.080663    0.43406835  0.32372198 -0.71427256  0.04448809]
  [-0.6438594  -1.1245041  -0.4723388  -0.8302859  -2.0056007 ]
  [ 1.1778332   0.2977344   0.7516829   1.1387901  -0.71768486]
  [-0.44642782 -0.6523012  -0.48157197 -0.8197472   0.3635474 ]]

 [[-0.43357274  1.166849   -0.04528571  0.44322303  0.74193203]
  [ 1.2332342   0.07857647  1.3399298   0.62153     1.835202  ]
  [ 0.48021084  0.36239776  0.16630112  0.59010863  1.8134127 ]
  [-1.1444335   1.2445287  -1.2320557   0.08095992 -0.1379302 ]
  [-1.101756   -1.8099649   0.18504284  0.15212883  0.33380997]]

 [[-0.68228734 -0.82357454 -0.744171   -0.04959428 -1.3200126 ]
  [ 0.813062    1.0669035  -0.7924809  -0.0548021   0.8043163 ]
  [ 1.6480085  -0.17134379  0.25517386  0.02731211  1.2226027 ]
  [-1.9785942  -0.22399756 -0.6814836   1.2065881  -1.7922156 ]
  [-0.34833568 -1.0567352   1.5795225   0.14899854  0.5924402 ]]

 [[-1.057639   -1.1659449  -0.22045298  0.39324322 -1.3500952 ]
  [-0.32044935  0.9534627   0.40809664 -1.0296333  -0.8129102 ]
  [-0.13515176 -0.32676768 -0.9333701   0.35130095 -1.5411847 ]
  [ 2.090785    0.3497966   0.27694222  0.78199005 -0.08591356]
  [ 0.9621986  -2.3930101  -1.1035724   0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)

'''
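
Since the goal in the question is pre-processing with tf.data.Dataset, here is a minimal sketch of how the apply function above could be plugged into a pipeline, assuming a placeholder dataset of random 5x5x5 volumes; note that tf.py_function executes the chosen function eagerly, so this step is not compiled into the graph:

import tensorflow as tf

# Placeholder dataset of eight random 5x5x5 "volumes" standing in for real data.
dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal((8, 5, 5, 5)))

# `apply` (defined above) picks one of the dispatcher functions per element.
dataset = dataset.map(apply)

for volume in dataset.take(2):
    print(volume.shape)  # (5, 5, 5)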

You can use a bunch of nested tf.cond. Depending on whether the condition is met, it will call either the true_fn or the false_fn. Since you have more than two functions, you can nest them for as many functions as you like. For instance, I'm making functions that multiply the input by either 2, 3, 4 or 5, depending on the value of a random variable.

import tensorflow as tf

x = 10


@tf.function
def mult_2():
    tf.print(f'i was 2, returning {x} multiplied by 2')
    return tf.multiply(x, 2)


@tf.function
def mult_3():
    tf.print(f'i was 3, returning {x} multiplied by 3')
    return tf.multiply(x, 3)


@tf.function
def mult_4():
    tf.print(f'i was 4, returning {x} multiplied by 4')
    return tf.multiply(x, 4)


@tf.function
def mult_5():
    tf.print(f'i was 5, returning {x} multiplied by 5')
    return tf.multiply(x, 5)


# maxval is exclusive, so this samples an integer from {2, 3, 4, 5}
i = tf.random.uniform((), 2, 6, dtype=tf.int32)

tf.cond(i == 2, mult_2,
        lambda: tf.cond(i == 3, mult_3,
                        lambda: tf.cond(i == 4, mult_4, mult_5)))

i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>

Note that mult_5 will execute if none of the conditions are met.
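
For the use case in the question, the same nested tf.cond pattern can also be written to accept the tensor argument. Below is a minimal sketch under that assumption, with simple transposes standing in for the real rotation functions:

import tensorflow as tf

# Placeholder "rotations": simple axis permutations of a rank-3 tensor.
def rot_a(t): return tf.transpose(t, perm=[0, 2, 1])
def rot_b(t): return tf.transpose(t, perm=[1, 0, 2])
def rot_c(t): return tf.transpose(t, perm=[2, 1, 0])

@tf.function
def apply_cond(tensor):
    i = tf.random.uniform((), minval=0, maxval=3, dtype=tf.int32)
    # Each branch is a zero-argument callable closing over `tensor`.
    return tf.cond(i == 0, lambda: rot_a(tensor),
                   lambda: tf.cond(i == 1, lambda: rot_b(tensor),
                                   lambda: rot_c(tensor)))

print(apply_cond(tf.random.normal((5, 5, 5))).shape)  # (5, 5, 5)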

You can use tf.switch_case like this:

import tensorflow as tf

def func1(tensor):
    return tensor * 1

def func2(tensor):
    return tensor * 2

def func24(tensor):
    return tensor * 24

class Lambda:
    def __init__(self, func, arg):
        self._func = func
        self._arg = arg
        
    def __call__(self):
        return self._func(self._arg)

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, func24]

    branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
    output = tf.switch_case(
        branch_index=branch_index, 
        branch_fns=[Lambda(func, tensor) for func in list_of_funcs], 
    )
    
    return output

The @tf.function decorator is needed only on the whole function you wish to optimize, which is apply in this case. If you use apply inside tf.data.Dataset.map, the decorator is not needed at all.

See this discussion to understand why we have to define the Lambda class here: plain lambdas created in a loop would all late-bind the loop variable, so every branch would end up calling the same function.
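
A functools.partial bound to the tensor would be one possible alternative to the Lambda class, since a partial is also a zero-argument callable and binds its arguments immediately rather than at call time. A minimal sketch, reusing func1, func2 and func24 from above:

import functools
import tensorflow as tf

@tf.function
def apply_partial(tensor):
    list_of_funcs = [func1, func2, func24]
    branch_index = tf.random.uniform(
        shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
    # Each partial already has `tensor` bound, so late binding is not an issue.
    return tf.switch_case(
        branch_index=branch_index,
        branch_fns=[functools.partial(func, tensor) for func in list_of_funcs],
    )

print(apply_partial(tf.random.normal((5, 5, 5))).shape)  # (5, 5, 5)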
