cost 98 ms
從 flax 模型的 init 調用中獲取不正確的輸出

[英]Getting incorrect output from the flax model's init call

我正在嘗試使用亞麻創建一個簡單的神經網絡,如下所示。 但是,我收到的params frozen dict 作為model.init的輸出是空的,而不是具有神經網絡的參數。 此外, type(predictions)是flax.linen.combinators.Sequential對象,而不是Dev ...

2022-12-04 12:20:21   1   26    jax / flax  
Jax - 對一批數據類進行 vmap

[英]Jax - vmap over batch of dataclasses

在 JAX 中,我希望在固定長度的數據類列表上 vmap 一個函數,例如: 上面的示例失敗,因為無法創建自定義對象的 jnp.Array,並且 JAX 不允許在 Python 列表上進行 vmapping。 我看到的唯一剩余選項是轉換數據類以表示一批參數,如下所示: 最好使用結構容器(每個結構代表 ...

如何將(亞麻)GRUCell 的隱藏 state(攜帶)初始化為可學習參數(例如使用 model.init)

[英]How can I initialize the hidden state (carry) of a (flax linen) GRUCell as a learnable parameter (e.g. using model.init)

我使用 Flax 在 Jax 中創建 GRU model 並使用 model.init 初始化 model 參數,如下所示:import jax.numpy as np from jax import random import flax.linen as nn from jax.nn impor ...

我正在嘗試將 JAX Tracer object 傳遞給需要 numpy 數組的模塊 - 請解決需要

[英]I am trying to pass a JAX Tracer object to a module that requires a numpy array - work around needed please

我是賈克斯的新手。 我正在使用 Jax 和 Flax 實現變分自動編碼器 (VAE)。 在訓練期間,我采樣了一個潛在代碼(來自編碼器推斷的分布,我使用 flax.linen.nn 模塊的組合來實現)。 至關重要的是,除了將此代碼通過解碼器(作為 VAE 的標准)傳遞之外,我還將代碼傳遞給外部 fun ...

2022-08-23 20:24:39   1   30    jax / flax  
有沒有辦法通過亞麻中的 self.put_variable 方法跟蹤畢業生?

[英]is there a way to trace grads through self.put_variable method in flax?

我想通過 self.put_variable 追蹤畢業生。 有沒有辦法讓這成為可能? 或者另一種更新提供給被跟蹤模塊的參數的方法? 我的輸出畢業生是: 什么表明它沒有通過 self.variable_put 方法跟蹤 grads,因為 grads 到 W 都是零,而 b 顯然依賴於 W。 ...

您可以從該模塊的 nn.compact 中更新模塊的參數嗎? (自我修改網絡)

[英]Can you update parameters of a module from inside the nn.compact of that module? (self modifying networks)

我對亞麻很陌生,我想知道獲得這種行為的正確方法是什么: 其中 f 是一個 nn.module 實例。 其中 f 可能會通過多個操作來獲取 new_param 並且這些操作可能依賴於中間參數來產生它們的輸出。 所以基本上,有沒有一種方法可以訪問和更新從__call__中提供給 nn.module ...

計算 Flax NN 輸出到輸入的 Hessian 向量積

[英]Calculating the Hessian Vector Product of a Flax NN output wrt to the inputs

我試圖獲得輸出的二階導數 wrt 使用 Flax 構建的神經網絡的輸入。 網絡結構如下: 我可以通過在 grad 上使用 vmap 來獲得單導數: 但是,當我再次嘗試這樣做以獲得二階導數時: 我收到以下錯誤: 我嘗試使用 autodiff 食譜中的 hvp 定義,但參數是函數的輸入, ...

Pickle 在 jax 中更改類型

[英]Pickle changes type in jax

我有一個包含 jax numpy 數組的 flax struct 數據類。 當我腌制轉儲這個對象並再次加載它時,該數組不再是一個 jax numpy 數組,而是轉換為一個 numpy 數組,這里是重現它的代碼: 我不想要這種行為,我希望它保持原來的類型,這可能嗎? ...

對於神經網絡,亞麻比純 Jax 慢得多?

[英]Flax much slower than pure Jax for neural nentworks?

對於一個項目,我正在嘗試編寫一個非常簡單的 MLP 示例,但我注意到 flax 中的實現比純 jax 實現慢大約 20 倍。 我在這里做錯了什么? 產生 output: ...

flax (google) 和 dm-haiku (deepmind) 之間的主要區別是什么?

[英]What is the main difference between flax (google) and dm-haiku (deepmind)?

亞麻和dm-haiku之間的主要區別是什么? 從他們的描述中: Flax,一個用於 JAX 的神經網絡庫 Haiku,受 Sonnet 啟發的 JAX 神經網絡庫問題: 我應該選擇哪個基於 jax 的庫來實現,比如說DeepSpeech model(由 CNN 層 + LSTM 層 + FC 組 ...


 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM