
How do you get the gradients of a loss function containing argmax in JAX?

I am facing this issue where I get zero gradients after using argmax in a loss function. I have created a minimal example:

import haiku as hk
import jax.numpy as jnp
import jax.random
import optax
import chex

hidden_dim = 64
input_shape = 12
num_classes = 2

class MLP(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)


        xavier_constant_1 = jnp.sqrt(6 / (input_shape + hidden_dim))
        xavier_constant_2 = jnp.sqrt(6 / (hidden_dim + hidden_dim))
        xavier_constant_3 = jnp.sqrt(6 / (hidden_dim + num_classes))
        self.seq = hk.Sequential([
            hk.Linear(hidden_dim, w_init=hk.initializers.RandomUniform(-xavier_constant_1, xavier_constant_1), b_init=hk.initializers.Constant(0.)),
            hk.Linear(hidden_dim, w_init=hk.initializers.RandomUniform(-xavier_constant_2, xavier_constant_2), b_init=hk.initializers.Constant(0.)),
            hk.Linear(num_classes, w_init=hk.initializers.RandomUniform(-xavier_constant_3, xavier_constant_3), b_init=hk.initializers.Constant(0.))
        ])

    def __call__(self, x: chex.Array):
        out = x.reshape((x.shape[0], -1))
        return self.seq(out)


def train_simulated():
    def mlp_fn(x):
        mlp = MLP('test_mlp')
        return mlp(x)
    mlp = hk.transform(mlp_fn)
    init, apply = hk.without_apply_rng(mlp)

    k1 = jax.random.PRNGKey(0)
    k2 = jax.random.PRNGKey(1)
    k3 = jax.random.PRNGKey(2)
    params = init(k1, jnp.ones((10, 12)))

    def loss_fn(parameters, x: chex.Array, y: chex.Array):
        y_hat = apply(parameters, x)
        preds = jnp.argmax(y_hat, axis=1)
        return ((preds.reshape(-1, 1) - y) ** 2).sum()

    loss_value_grad = jax.value_and_grad(loss_fn)
    x = jax.random.uniform(k2, (10, 12))
    y = (jax.random.uniform(k3, (10, 1)) > 0.5).astype(float)
    v, g = loss_value_grad(params, x, y)
    print(g)

if __name__ == '__main__':
    train_simulated()

The output of the code is the gradients of the loss function with respect to the parameters. However, all of the gradients are zero. This is unexpected, because the labels and the inputs are generated randomly.

When you're using a sorting-based computation like argmax, zero is often the correct gradient. For more discussion of this, see the FAQ entry "Why are gradients zero for functions based on sort order?" in the JAX documentation.
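
If the model is meant to be trained by gradient descent, the usual fix is to make the loss a smooth function of the logits rather than of the hard argmax index. Below is a minimal sketch of two such alternatives, assuming the same apply, params, and data as in the example above; the names ce_loss_fn and soft_argmax_loss_fn (and the choice of optax.softmax_cross_entropy_with_integer_labels) are illustrative additions, not part of the original answer.

def ce_loss_fn(parameters, x: chex.Array, y: chex.Array):
    # Compute the loss directly on the logits with a softmax
    # cross-entropy, which is smooth and has nonzero gradients.
    y_hat = apply(parameters, x)              # (batch, num_classes) logits
    labels = y.reshape(-1).astype(jnp.int32)  # integer class labels
    return optax.softmax_cross_entropy_with_integer_labels(y_hat, labels).mean()

def soft_argmax_loss_fn(parameters, x: chex.Array, y: chex.Array):
    # A "soft argmax": replace the hard index with a softmax-weighted
    # expectation over class indices, which is differentiable in the logits.
    y_hat = apply(parameters, x)
    probs = jax.nn.softmax(y_hat, axis=1)
    preds = probs @ jnp.arange(y_hat.shape[1], dtype=y_hat.dtype)
    return ((preds.reshape(-1, 1) - y) ** 2).sum()

Either function can be passed to jax.value_and_grad in place of loss_fn above and will produce nonzero gradients.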
