簡體   English   中英

有沒有辦法通過亞麻中的 self.put_variable 方法跟蹤畢業生?

[英]is there a way to trace grads through self.put_variable method in flax?

我想通過 self.put_variable 追蹤畢業生。 有沒有辦法讓這成為可能? 或者另一種更新提供給被跟蹤模塊的參數的方法?

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 


class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

      
        self.put_variable("params","b",(x@W+b).reshape(5,))  
    
        return jnp.sum(x+b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)
    #x,param = module.apply(param,x,mutable=["params"])
    #print(param)
    print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))

我的輸出畢業生是:

FrozenDict({
    params: {
        W: DeviceArray([[0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },

什么表明它沒有通過 self.variable_put 方法跟蹤 grads,因為 grads 到 W 都是零,而 b 顯然依賴於 W。

您的模型的輸出是jnp.sum(x + b) ,它不依賴於W ,這反過來意味着相對於W的梯度應該為零。 考慮到這一點,您在上面顯示的輸出看起來是正確的。

編輯:聽起來您希望在變量中使用的x@W+b的結果反映在 return 語句中使用的b值中; 也許你想要這樣的東西?

    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

        b = x@W+b
        self.put_variable("params","b",b.reshape(5,)) 
    
        return jnp.sum(x+b)

也就是說,我不清楚你的最終目標是什么,鑒於你問的是這樣一個不常見的構造,我懷疑這可能是一個XY 問題 也許您可以編輯您的問題,以更多地說明您要完成的工作。

就像@jakevdp 指出的那樣,上面的測試是不正確的,因為 b 仍然與前一個 b 相關聯。
https://github.com/google/flax/discussions/2215說 self.put_variable 被追蹤。

使用以下代碼測試是否確實如此:

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 

class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

        b = x@W+b #update the b variable else it is still tied to the previous one.
        self.put_variable("params","b",(b).reshape(5,))  
     
        return jnp.sum(x+b)

def test_update(param,x):
    _, param = module.apply(param,x,mutable=["params"])
    return jnp.sum(param["params"]["b"]+x),param 

if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)

    print(grad(test_update,has_aux=True)(param,x))

輸出:

FrozenDict({
    params: {
        W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                       0.00599653],
                     [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                       0.0097266 ],
                     [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                      -0.00616944],
                     [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                       0.00252594],
                     [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                       0.00956188]], dtype=float32),
        b: DeviceArray([-0.00905032, -0.00574646,  0.01621638, -0.01165553,
                     -0.0285466 ], dtype=float32),
    },
})
(FrozenDict({
    params: {
        W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
                      -1.1489547 ],
                     [-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
                      -2.0069852 ],
                     [ 0.98777294,  0.98777294,  0.98777294,  0.98777294,
                       0.98777294],
                     [ 0.9311977 ,  0.9311977 ,  0.9311977 ,  0.9311977 ,
                       0.9311977 ],
                     [-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
                      -0.2883922 ]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },
}), FrozenDict({
    params: {
        W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                       0.00599653],
                     [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                       0.0097266 ],
                     [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                      -0.00616944],
                     [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                       0.00252594],
                     [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                       0.00956188]], dtype=float32),
        b: DeviceArray([-0.01861148, -0.00523183,  0.03968921, -0.01952654,
                     -0.06145691], dtype=float32),
    },
}))

第一個 FrozenDict 是原始參數。
第二個 FrozenDict 是畢業生,顯然是通過 self.put_variable 追蹤的。
最后一個 FrozenDict 是參數,在這里我們可以看到 b 被正確更新了。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM