簡體   English   中英

為什么我不能用這個網絡和約束學習 XOR 函數?

[英]Why can't I learn XOR function with this network and constraints?

假設我有以下約束和網絡:

  1. 架構是固定的(見這張圖片)(注意沒有偏見)
  2. 隱藏層的激活函數是 ReLU
  3. 輸出層沒有激活函數(應該只返回它接收到的輸入的總和)。

我嘗試在 pytorch 中使用各種初始化方案和不同的數據集來實現這一點,但我失敗了(代碼在底部)。

我的問題是:

  1. 我的 NN 訓練過程有什么問題嗎?
  2. 這是一個可行的問題嗎? 如果是,如何?
  3. 如果這是可行的,我們仍然可以通過將權重限制在集合 {-1, 0, 1} 中來實現嗎?

代碼:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import numpy as np

class Network(nn.Module):

     def __init__(self):

        super(Network, self).__init__()
        self.fc1 = nn.Linear(2,2,bias=False)
        self.fc2 = nn.Linear(2,1, bias=False)
        self.rl = nn.ReLU()

      def forward(self, x):
        x = self.fc1(x)
        x = self.rl(x)
        x = self.fc2(x)
        return x 

#create an XOR data set to train    
rng = np.random.RandomState(0)
X = rng.randn(200, 2)
y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0).astype('int32')

# test data set
X_test = np.array([[0,0],[0,1], [1,0], [1,1]])

train = data_utils.TensorDataset(torch.from_numpy(X).float(), \
                         torch.from_numpy(y).float())
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

test = torch.from_numpy(X_test).float()

# training the network
num_epoch = 10000

net = Network()
net.fc1.weight.data.clamp_(min=-1, max=1)
net.fc2.weight.data.clamp_(min=-1, max=1)

# define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters())

for epoch in range(num_epoch):
  running_loss = 0 # loss per epoch
  for (X, y)in train_loader:
    # make the grads zero
    optimizer.zero_grad()
    # forward propagate
    out = net(X)
    # calculate loss and update
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    running_loss += loss.data
  if epoch%500== 0:      
      print("Epoch: {0} Loss: {1}".format(epoch, running_loss))

損失沒有改善。 它在幾個時期后陷入了某個值(我不確定如何使這個可重現,因為我每次都得到不同的值)

net(test)返回一組與 XOR 輸出相差無幾的預測。

您需要在隱藏層和輸出層中使用非線性激活函數,例如 sigmoid。 因為 xor 不是線性可分的。還需要偏差。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM