JAX - jitting functions: parameters vs "global" variables
I have the following question about JAX. I'll use an example from the official optax docs to illustrate it:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state, batch, labels):
        loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
        params, opt_state, loss_value = step(params, opt_state, batch, labels)
        if i % 100 == 0:
            print(f'step {i}, loss: {loss_value}')

    return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
In this example, the function step uses the variable optimizer despite it not being passed among the function arguments (since the function is being jitted and optax.GradientTransformation is not a supported type). However, the same function uses other variables that are instead passed as arguments (i.e., params, opt_state, batch, labels). I understand that JAX functions need to be pure in order to be jitted, but what about input (read-only) variables? Is there any difference between accessing a variable by passing it through the function arguments and accessing it directly, since it is in the step function's scope? What if this variable is not constant but is modified between separate step calls? Are such variables treated like static arguments if accessed directly? Or are they simply jitted away, so that modifications to them will not be taken into account?
To be more specific, let's look at the following example:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)
    extra_learning_rate = 0.1

    @jax.jit
    def step(params, opt_state, batch, labels):
        loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        updates *= extra_learning_rate  # not really valid code, but you get the idea
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
        extra_learning_rate = 0.1
        params, opt_state, loss_value = step(params, opt_state, batch, labels)
        extra_learning_rate = 0.01  # does this affect the next `step` call?
        params, opt_state, loss_value = step(params, opt_state, batch, labels)

    return params
vs
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)
    extra_learning_rate = 0.1

    @jax.jit
    def step(params, opt_state, batch, labels, extra_lr):
        loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        updates *= extra_lr  # not really valid code, but you get the idea
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
        extra_learning_rate = 0.1
        params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
        extra_learning_rate = 0.01  # does this now affect the next `step` call?
        params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

    return params
From my limited experiments, they perform exactly the same (i.e., the second step call correctly uses the new learning rate in both cases) and no 're-jitting' happens; however, I'd like to know whether there are any standard practices or rules I need to be aware of. I'm writing a library where performance is fundamental, and I don't want to miss out on JIT optimizations because I'm doing things the wrong way.
During JIT tracing, JAX treats global values as implicit arguments to the function being traced. You can see this reflected in the jaxpr representing the function.
Here are two simple functions that return equivalent results, one with implicit arguments and one with explicit:
import jax
import jax.numpy as jnp

def f_explicit(a, b):
    return a + b

def f_implicit(b):
    return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }

print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }
Notice that the only difference between the two jaxprs is that in f_implicit, the a variable comes before the semicolon: this is how jaxpr representations indicate that the argument is passed via closure rather than as an explicit argument. But the computation generated by these two functions will be identical.
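As a quick sanity check (a sketch reusing the two functions above, with the same definitions assumed), one can confirm that both forms produce identical results when jitted:

```python
import jax
import jax.numpy as jnp

def f_explicit(a, b):
    return a + b

def f_implicit(b):
    return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

# Both jitted versions compile to the same computation and agree exactly
r1 = jax.jit(f_explicit)(a_global, b)
r2 = jax.jit(f_implicit)(b)
print(jnp.array_equal(r1, r2))  # True
```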
That said, one difference to be aware of is that when a value passed by closure is a hashable constant, it will be treated as static within the traced function (similar to when explicit arguments are marked static via static_argnums or static_argnames within jax.jit):
a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }
Notice that in the jaxpr representation the constant value is inserted directly as an operand of the add operation. The explicit way to get the same result for a JIT-compiled function would look something like this:
from functools import partial

@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
    return a + b
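One practical consequence is worth spelling out, since it bears directly on the "modified between step calls" question. The sketch below uses made-up names (not part of optax or the examples above): a constant captured by closure is baked into the cached trace, so rebinding the Python variable afterwards has no effect on later calls, whereas a value passed as a static argument triggers a re-trace whenever it changes:

```python
import jax
from functools import partial

scale = 2.0  # captured by closure at trace time

@jax.jit
def times_scale(x):
    return x * scale

print(times_scale(3.0))  # 6.0 -- traced with scale == 2.0
scale = 10.0
print(times_scale(3.0))  # still 6.0: the cached trace baked in the old value

@partial(jax.jit, static_argnames=['s'])
def times_s(x, s):
    return x * s

print(times_s(3.0, 2.0))   # 6.0
print(times_s(3.0, 10.0))  # 30.0 -- a new static value triggers a re-trace
```

So if a value genuinely changes between calls, the safest choices are to pass it as an ordinary traced argument (no re-compilation, value used at run time) or as a static argument (re-compilation on each new value); relying on a mutated closure variable silently reuses the stale trace.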