繁体   English   中英

如何创建简单的3层神经网络并使用有监督的学习教它?

[英]How to create simple 3-layer neural network and teach it using supervised learning?

基于PyBrain的教程,我设法将以下代码拼凑在一起:

#!/usr/bin/env python2
# coding: utf-8

from pybrain.structure import FeedForwardNetwork, LinearLayer, SigmoidLayer, FullConnection
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer

n = FeedForwardNetwork()

inLayer = LinearLayer(2)
hiddenLayer = SigmoidLayer(3)
outLayer = LinearLayer(1)

n.addInputModule(inLayer)
n.addModule(hiddenLayer)
n.addOutputModule(outLayer)

in_to_hidden = FullConnection(inLayer, hiddenLayer)
hidden_to_out = FullConnection(hiddenLayer, outLayer)

n.addConnection(in_to_hidden)
n.addConnection(hidden_to_out)

n.sortModules()

ds = SupervisedDataSet(2, 1)
ds.addSample((0, 0), (0,))
ds.addSample((0, 1), (1,))
ds.addSample((1, 0), (1,))
ds.addSample((1, 1), (0,))

trainer = BackpropTrainer(n, ds)
# trainer.train()
trainer.trainUntilConvergence()

print n.activate([0, 0])[0]
print n.activate([0, 1])[0]
print n.activate([1, 0])[0]
print n.activate([1, 1])[0]

它应该学习XOR函数,但结果似乎很随机:

0.208884929522

0.168926515771

0.459452834043

0.424209192223

要么

0.84956138664

0.888512762786

0.564964077401

0.611111147862

您的方法存在四个问题,在阅读神经网络常见问题解答后都很容易识别:

  • 为什么要使用偏差/阈值? :你应该添加一个偏向节点。 缺乏偏见使得学习非常有限:由网络代表的分离超平面只能通过原点。 使用偏置节点,它可以自由移动并更好地适应数据:

     bias = BiasUnit() n.addModule(bias) bias_to_hidden = FullConnection(bias, hiddenLayer) n.addConnection(bias_to_hidden) 
  • 为什么不将二进制输入编码为0和1? :您的所有样本都位于样本空间的单个象限中。 移动它们分散在原点周围:

     ds = SupervisedDataSet(2, 1) ds.addSample((-1, -1), (0,)) ds.addSample((-1, 1), (1,)) ds.addSample((1, -1), (1,)) ds.addSample((1, 1), (0,)) 

    (相应地修复脚本末尾的验证代码。)

  • trainUntilConvergence方法使用验证工作,并执行类似于早期停止方法的操作 这对于如此小的数据集没有意义。 请改用trainEpochs 1000时代足以让网络学习这个问题:

     trainer.trainEpochs(1000) 
  • 什么学习率应该用于backprop? :调整学习速率参数。 这是你每次使用神经网络时都要做的事情。 在这种情况下,值0.1或甚至0.2显着提高学习速度:

     trainer = BackpropTrainer(n, dataset=ds, learningrate=0.1, verbose=True) 

    (注意verbose=True参数。在调整参数时,观察错误的行为是必不可少的。)

通过这些修复,我得到了给定网络的给定数据集的一致性和正确结果,并且误差小于1e-23

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM