
Gradient Accumulation with JAX

I made a simple script to try to do gradient accumulation with JAX. The idea is to have a large batch size (e.g. 64) that is split into small chunks (e.g. 4) that fit in the GPU's memory. For each chunk, the resulting gradient, stored in a pytree, is added to the current batch gradient. The update is done only once all chunks of the large batch have been computed. In this particular example, we simply try to fit random 512-dimensional vectors to random booleans with a linear layer. Here is the script:

import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable


def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):

    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = {
            'I': init_kernel(sub_key1, (input_size, output_size) ),
            'I_b': init_bias(sub_key2, (1,output_size) ),
        }
        return params

    def apply_fun(params, inputs):
        I, I_b, = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)


def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree


def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip( jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2) ):
        pt1 = pt1 + pt2
    return pytree1


rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50

model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N//K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K,N,K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value

def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels-predictions).sum()
    return loss

for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)

I have doubts about divide_pytree and add_pytrees. Do they actually modify the current batch gradient, or am I missing something? Moreover, do you see any speed issue with this code? In particular, should I use jax.lax.fori_loop in place of the traditional Python for loop?


Regarding the pytree computations: as written, your functions return the input unmodified. The better approach is to use jax.tree_util.tree_map; for example:

from jax.tree_util import tree_map

def divide_pytree(pytree, div):
  return tree_map(lambda pt: pt / div, pytree)

def add_pytrees(pytree1, pytree2):
  return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)
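
For instance, a quick check with made-up shapes (hypothetical values, not taken from the question) shows that the corrected helpers return new leaves rather than modifying anything in place:

import jax.numpy as jnp

grads = {'I': jnp.ones((2, 3)), 'I_b': jnp.ones((1, 3))}
halved = divide_pytree(grads, 2.0)    # every leaf becomes 0.5
summed = add_pytrees(halved, halved)  # every leaf is back to 1.0
print(grads['I'])                     # the original pytree is untouched: still all 1.0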

Regarding performance: anything in the for loop will be flattened when JIT-compiled, with one repeated copy of all the XLA instructions per iteration of the loop. If you have 5 iterations, that's not really an issue. If you have 5000, it would significantly slow down compilation times (because XLA needs to analyze & optimize 5000 explicit copies of the instructions in the loop).

fori_loop can help, but it does not lead to optimal code, particularly when running on CPU and GPU.
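
If you do want a structured loop, here is a rough sketch (mine, not part of the original answer) of the accumulation written with jax.lax.fori_loop. It reuses loss_func, get_params, opt_update and accumulation_size from the script above, and compiles the loop body once instead of unrolling it:

import jax
import jax.numpy as jnp
from jax import jit
from jax.tree_util import tree_map

@jit
def update_fori(i, current_opt_state, current_batch):
    inputs, labels = current_batch
    K = accumulation_size              # static chunk size
    num_chunks = inputs.shape[0] // K
    params = get_params(current_opt_state)

    def body(chunk_idx, carry):
        value_acc, grads_acc = carry
        # Slice out chunk `chunk_idx` (dynamic start index, static size K).
        x = jax.lax.dynamic_slice_in_dim(inputs, chunk_idx * K, K, axis=0)
        y = jax.lax.dynamic_slice_in_dim(labels, chunk_idx * K, K, axis=0)
        value, grads = jax.value_and_grad(loss_func)(params, (x, y))
        value_acc = value_acc + value / num_chunks
        grads_acc = tree_map(lambda a, g: a + g / num_chunks, grads_acc, grads)
        return value_acc, grads_acc

    init = (jnp.zeros(()), tree_map(jnp.zeros_like, params))
    value, grads = jax.lax.fori_loop(0, num_chunks, body, init)
    return opt_update(i, grads, current_opt_state), value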

Better would be to use broadcasted or vmapped operations where possible to express the logic of the loops without explicit looping.
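
As a sketch of that last suggestion (again mine, under the same imports and globals as the sketch above): reshape the large batch into (num_chunks, K, ...) and vmap the per-chunk loss and gradient, then average with tree_map. Note that this evaluates all chunks together, so it illustrates the loop-free style rather than the memory savings of true accumulation:

@jit
def update_vmapped(i, current_opt_state, current_batch):
    inputs, labels = current_batch
    K = accumulation_size
    num_chunks = inputs.shape[0] // K
    params = get_params(current_opt_state)

    # Map the per-chunk value_and_grad over the leading (chunk) axis.
    values, grads = jax.vmap(
        lambda x, y: jax.value_and_grad(loss_func)(params, (x, y))
    )(inputs.reshape(num_chunks, K, -1), labels.reshape(num_chunks, K, -1))

    # Average per-chunk losses and gradients, matching the loop's scaling.
    value = values.mean()
    grads = tree_map(lambda g: g.mean(axis=0), grads)
    return opt_update(i, grads, current_opt_state), value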
