JAX - jitting functions: parameters vs "global" variables

I have the following question about JAX. I'll use an example from the official optax docs to illustrate it:

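# Note: `loss`, `TRAINING_DATA`, `LABELS`, and `initial_params` are defined
# earlier in the optax docs example this snippet comes from.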
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

In this example, the function step uses the variable optimizer despite it not being passed in the function arguments (since the function is being jitted and optax.GradientTransformation is not a supported type). However, the same function uses other variables that are instead passed as parameters (i.e., params, opt_state, batch, labels). I understand that JAX functions need to be pure in order to be jitted, but what about input (read-only) variables? Is there any difference between passing a variable through the function arguments and accessing it directly because it is in the step function's enclosing scope? What if this variable is not constant but is modified between separate step calls? Are directly accessed variables treated like static arguments? Or are they simply jitted away, so that modifications to them will not be considered?

To be more specific, let's look at the following example:

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # `updates` is a pytree, so scale each leaf by the extra learning rate
    updates = jax.tree_util.tree_map(lambda u: u * extra_learning_rate, updates)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    extra_learning_rate = 0.01 # does this affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels)

  return params

vs

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels, extra_lr):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # `updates` is a pytree, so scale each leaf by the extra learning rate
    updates = jax.tree_util.tree_map(lambda u: u * extra_lr, updates)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
    extra_learning_rate = 0.01 # does this now affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

  return params

From my limited experiments, the two versions perform exactly the same (i.e., the second step call correctly uses the new learning rate in both cases), and no 're-jitting' happens. However, I'd like to know whether there are any standard practices or rules I need to be aware of. I'm writing a library where performance is fundamental, and I don't want to miss out on jit optimizations because I'm doing things wrong.
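
For reference, one simple way to observe whether a call re-traces is a trace-time print, since Python side effects run only while JAX traces a function. A minimal sketch with a stand-in step (not the real one above):

import jax

@jax.jit
def step(x, lr):
  # This print runs only during tracing, so it fires once per
  # compilation rather than once per call.
  print('tracing step')
  return x * lr

step(1.0, 0.1)   # prints 'tracing step': the first call triggers a trace
step(1.0, 0.01)  # no print: same shapes/dtypes, the cached version is reused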

During JIT tracing, JAX treats global values as implicit arguments to the function being traced. You can see this reflected in the jaxpr representing the function.

Here are two simple functions that return equivalent results, one with implicit arguments and one with explicit arguments:

import jax
import jax.numpy as jnp

def f_explicit(a, b):
  return a + b

def f_implicit(b):
  return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }

print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }

Notice the only difference between the two jaxprs is that in f_implicit, the a variable comes before the semicolon: this is how jaxpr representations indicate that the value is passed via closure rather than via an explicit argument. But the computation generated by these two functions will be identical.
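
One practical caveat (a sketch of jit's standard caching behavior, reusing f_implicit and b from above): jax.make_jaxpr re-traces the function on every call, but jax.jit caches the traced computation, keyed on the explicit arguments only. A value reached through closure is captured when the function is first traced, and rebinding the global afterwards does not invalidate the cache:

jitted_implicit = jax.jit(f_implicit)

print(jitted_implicit(b))
# [1. 2. 3. 4. 5.]  (a_global is captured during this first trace)

a_global = jnp.zeros(5)
print(jitted_implicit(b))
# still [1. 2. 3. 4. 5.]: the argument shapes are unchanged, so the cached
# trace, with the old a_global baked in, is reused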

That said, one difference to be aware of is that when a value passed by closure is a hashable constant, it will be treated as static within the traced function (similar to when explicit arguments are marked static via static_argnums or static_argnames within jax.jit):

a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }

Notice that in the jaxpr representation the constant value is inserted directly as an operand to the add operation. The explicit way to get the same result for a JIT-compiled function would look something like this:

from functools import partial

@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
  return a + b
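
Unlike a value reached through closure, a static argument participates in jit's cache key: calling with a different static value triggers a fresh trace that picks up the new value, at the cost of a recompilation. A minimal sketch, reusing the b defined above:

print(f_explicit(1.0, b))  # traces and compiles with a=1.0 baked in: [2. 2. 2. 2. 2.]
print(f_explicit(2.0, b))  # new static value, so a fresh trace/compile: [3. 3. 3. 3. 3.]

Applied to the question's extra_learning_rate: passed as an ordinary (traced) argument, it can change freely between step calls with no recompilation; reached through closure, it is baked in at the first trace and later rebindings are silently ignored; marked static, each new value is picked up but costs a new compilation.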
