JAX: return early if a condition holds, otherwise continue, in a jitted function

I want to replicate this behaviour in a jitted function (the function is an example):

def function(x,y):
   if y==0:
      return x
   return x+1

Using jax.lax.cond it can be obtained with:

import jax

@jax.jit
def function(x, y):
    # Both branches receive x as their operand
    return jax.lax.cond(y == 0, lambda x: x, lambda x: x + 1, x)

This is fine as long as the work to do when y != 0 is simple (here, just adding 1 to x). However, if that work is complex, or there are several conditions of this kind, the code gets convoluted.

Is there a way to get the behavior "if y == 0, return x; if not, just keep running the function"? jax.lax.cond requires a new function for every condition that is applied.

For example, take this function with several early returns:

def function(x,y):
    if y==0:
       return x
    if y>0:
       return x-y
    if y<0:
       return x+y

The jitted version starts to get messy:

@jax.jit
def function(x, y):
    # Outer cond handles y == 0; the nested cond dispatches on the sign of y
    return jax.lax.cond(
        y == 0,
        lambda x, y: x,
        lambda x, y: jax.lax.cond(y > 0, lambda x, y: x - y, lambda x, y: x + y, x, y),
        x, y)

Is there a better way?

In short, no, there's no way to return early from a Python function conditioned on traced values. The pattern I typically see to avoid messy nesting is to encapsulate logic in helper functions and call them via lax.cond.
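As a rough sketch of that helper-function pattern applied to the three-way example from the question: the helper names (_when_zero, _when_positive, _when_negative, _when_nonzero) are made up for illustration, and the sketch assumes x and y are scalars of the same dtype so that every branch returns the same type.

```python
import jax
from jax import lax

# Hypothetical helper names; each one holds the logic of a single branch.
def _when_zero(x, y):
    # y == 0: return x unchanged
    return x

def _when_positive(x, y):
    # y > 0
    return x - y

def _when_negative(x, y):
    # y < 0
    return x + y

def _when_nonzero(x, y):
    # y != 0: dispatch on the sign of y
    return lax.cond(y > 0, _when_positive, _when_negative, x, y)

@jax.jit
def function(x, y):
    # The top level reads like the original chain of early returns
    return lax.cond(y == 0, _when_zero, _when_nonzero, x, y)
```

The nesting is still there, but each level is a named function, so the jitted entry point stays a single line and each branch can grow without making the call harder to read.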

Alternatively, if you are branching based on multiple conditions, you may be able to better express the logic in terms of lax.switch; for example:

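import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def function(x, y):
  # One zero-argument branch per condition, in priority order
  branches = [lambda: x, lambda: x-y, lambda: x+y]
  # The trailing True is the fallback when no earlier condition holds
  conditions = jnp.array([y == 0, y > 0, True])
  # argmax gives the index of the first True condition
  return lax.switch(jnp.argmax(conditions), branches)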
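This works because jnp.argmax returns the index of the first maximal entry, so the first true condition wins and the trailing True acts as the final else. A quick way to convince yourself is to compare against the plain-Python version from the question. The sketch below assumes the second condition is meant to test the sign of y, as in that plain-Python version; the name reference is made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def reference(x, y):
    # Plain-Python version with early returns, as in the question
    if y == 0:
        return x
    if y > 0:
        return x - y
    return x + y

@jax.jit
def function(x, y):
    branches = [lambda: x, lambda: x - y, lambda: x + y]
    # Assumption: the second condition tests the sign of y
    conditions = jnp.array([y == 0, y > 0, True])
    return lax.switch(jnp.argmax(conditions), branches)

# Check one input per branch against the reference
for x, y in [(5.0, 0.0), (5.0, 2.0), (5.0, -2.0)]:
    assert float(function(x, y)) == reference(x, y)
```

Note that the conditions are all computed up front, but only the selected branch is executed.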
