[英]is there a way to trace grads through self.put_variable method in flax?
我想通過 self.put_variable 追蹤畢業生。 有沒有辦法讓這成為可能? 或者另一種更新提供給被跟蹤模塊的參數的方法?
import jax
from jax import numpy as jnp
from jax import grad,random,jit,vmap
import flax
from flax import linen as nn
class network(nn.Module):
input_size : int
output_size : int
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
self.put_variable("params","b",(x@W+b).reshape(5,))
return jnp.sum(x+b)
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
#x,param = module.apply(param,x,mutable=["params"])
#print(param)
print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))
我的輸出畢業生是:
FrozenDict({
params: {
W: DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
什么表明它沒有通過 self.variable_put 方法跟蹤 grads,因為 grads 到 W 都是零,而 b 顯然依賴於 W。
您的模型的輸出是jnp.sum(x + b)
,它不依賴於W
,這反過來意味着相對於W
的梯度應該為零。 考慮到這一點,您在上面顯示的輸出看起來是正確的。
編輯:聽起來您希望在變量中使用的x@W+b
的結果反映在 return 語句中使用的b
值中; 也許你想要這樣的東西?
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b
self.put_variable("params","b",b.reshape(5,))
return jnp.sum(x+b)
也就是說,我不清楚你的最終目標是什么,鑒於你問的是這樣一個不常見的構造,我懷疑這可能是一個XY 問題。 也許您可以編輯您的問題,以更多地說明您要完成的工作。
就像@jakevdp 指出的那樣,上面的測試是不正確的,因為 b 仍然與前一個 b 相關聯。
https://github.com/google/flax/discussions/2215說 self.put_variable 被追蹤。
使用以下代碼測試是否確實如此:
import jax
from jax import numpy as jnp
from jax import grad,random,jit,vmap
import flax
from flax import linen as nn
class network(nn.Module):
input_size : int
output_size : int
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b #update the b variable else it is still tied to the previous one.
self.put_variable("params","b",(b).reshape(5,))
return jnp.sum(x+b)
def test_update(param,x):
_, param = module.apply(param,x,mutable=["params"])
return jnp.sum(param["params"]["b"]+x),param
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
print(grad(test_update,has_aux=True)(param,x))
輸出:
FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.00905032, -0.00574646, 0.01621638, -0.01165553,
-0.0285466 ], dtype=float32),
},
})
(FrozenDict({
params: {
W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
-1.1489547 ],
[-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
-2.0069852 ],
[ 0.98777294, 0.98777294, 0.98777294, 0.98777294,
0.98777294],
[ 0.9311977 , 0.9311977 , 0.9311977 , 0.9311977 ,
0.9311977 ],
[-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
-0.2883922 ]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
}), FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.01861148, -0.00523183, 0.03968921, -0.01952654,
-0.06145691], dtype=float32),
},
}))
第一個 FrozenDict 是原始參數。
第二個 FrozenDict 是畢業生,顯然是通過 self.put_variable 追蹤的。
最后一個 FrozenDict 是參數,在這里我們可以看到 b 被正確更新了。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.