簡體   English   中英

將帶有 ReLU 的神經網絡擬合到多項式函數

[英]Fitting a neural network with ReLUs to polynomial functions

出於好奇,我試圖將帶有修正線性單元的神經網絡擬合為多項式函數。 例如,我想看看神經網絡提出函數f(x) = x^2 + x的近似值有多容易(或困難)。 下面的代碼應該可以做到,但似乎什么也沒學到。 當我跑

using Base.Iterators: repeated
ENV["JULIA_CUDA_SILENT"] = true
using Flux
using Flux: throttle
using Random

f(x) = x^2 + x
x_train = shuffle(1:1000)
y_train = f.(x_train)
x_train = hcat(x_train...)

m = Chain(
    Dense(1, 45, relu),
    Dense(45, 45, relu),
    Dense(45, 1),
    softmax
)

function loss(x, y) 
    Flux.mse(m(x), y)
end

evalcb = () -> @show(loss(x_train, y_train))
opt = ADAM()

@show loss(x_train, y_train)

dataset = repeated((x_train, y_train), 50)

Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))

println("Training finished")

@show m([20])

它返回

loss(x_train, y_train) = 2.0100101f14
loss(x_train, y_train) = 2.0100101f14
loss(x_train, y_train) = 2.0100101f14
Training finished
m([20]) = Float32[1.0]

這里的任何人都知道我如何使網絡適合f(x) = x^2 + x

您的試驗似乎有一些問題,主要與您如何使用優化器和處理輸入有關——Julia 或 Flux 沒有問題。 提供的解決方案確實可以學習,但絕不是最佳的。

  • 在回歸問題上激活 softmax 輸出是沒有意義的。 Softmax 用於分類問題,其中模型的輸出表示概率,因此應該在區間 (0,1) 上。 很明顯,您的多項式具有此區間之外的值。 在像這樣的回歸問題中通常會有線性輸出激活。 這意味着在 Flux 中不應在輸出層上定義輸出激活。
  • 數據的形狀很重要。 train! 計算loss(d...)梯度,其中ddata一個批次。 在您的情況下,一個小批量包含 1000 個樣本,並且同一批次重復 50 次。 神經網絡通常使用較小的批次進行訓練,但使用較大的樣本集。 在我提供的代碼中,所有批次都包含不同的數據。
  • 對於訓練神經網絡,通常建議對您的輸入進行歸一化。 您的輸入采用 1 到 1000 之間的值。我的示例應用了一個簡單的線性變換來獲取正確范圍內的輸入數據。
  • 歸一化也可以應用於輸出。 如果輸出很大,這可能會導致(太大)梯度和權重更新。 另一種方法是大量降低學習率。
using Flux
using Flux: @epochs
using Random

normalize(x) = x/1000
function generate_data(n)
    f(x) = x^2 + x
    xs = reduce(hcat, rand(n)*1000)
    ys = f.(xs)
    (normalize(xs), normalize(ys))
end
batch_size = 32
num_batches = 10000
data_train = Iterators.repeated(generate_data(batch_size), num_batches)
data_test = generate_data(100)


model = Chain(Dense(1,40, relu), Dense(40,40, relu), Dense(40, 1))
loss(x,y) = Flux.mse(model(x), y)

opt = ADAM()
ps = Flux.params(model)
Flux.train!(loss, ps, data_train, opt , cb = () -> @show loss(data_test...))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM