
Can someone check what is wrong with my XOR neural network code?

I've been trying to build an XOR neural network, but the outputs always converge to a single value (such as 1, 0, or 0.5) for all inputs. This is my latest attempt:

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) {
        double[][] trainingInputs = {
                {1, 1},
                {1, 0},
                {0, 1},
                {1, 1}
        };
        double[] targetOutputs = {0, 1, 1, 0};
        NeuralNetwork network = new NeuralNetwork();
        System.out.println("Training");
        for(int i=0; i<40; i++) {
            network.train(trainingInputs, targetOutputs);
        }
        for(double[] inputs : trainingInputs) {
            double output = network.feedForward(inputs);
            System.out.println(inputs[0] + " - " + inputs[1] + " : " + output);
        }
    }
}

class Neuron {
    private ArrayList<Synapse> inputs; // synapses feeding into this neuron
    private double output; // the neuron's output
    private double derivative; // derivative of the output
    private double weightedSum; // weighted sum of synapse weights times the outputs of their source neurons
    private double error; // error
    public Neuron() {
        inputs = new ArrayList<Synapse>();
        error = 0;
    }
    // Adds a synapse
    public void addInput(Synapse input) {
        inputs.add(input);
    }

    public List<Synapse> getInputs() {
        return this.inputs;
    }

    public double[] getWeights() {
        double[] weights = new double[inputs.size()];

        int i = 0;
        for(Synapse synapse : inputs) {
            weights[i] = synapse.getWeight();
            i++;
        }

        return weights;
    }

    private void calculateWeightedSum() {
        weightedSum = 0;
        for(Synapse synapse : inputs) {
            weightedSum += synapse.getWeight() * synapse.getSourceNeuron().getOutput();
        }
    }

    public void activate() {
        calculateWeightedSum();
        output = sigmoid(weightedSum);
        derivative = sigmoidDerivative(output);
    }

    public double getOutput() {
        return this.output;
    }

    public void setOutput(double output) {
        this.output = output;
    }

    public double getDerivative() {
        return this.derivative;
    }

    public double getError() {
        return error;
    }

    public void setError(double error) {
        this.error = error;
    }

    public double sigmoid(double weightedSum) {
        return 1 / (1 + Math.exp(-weightedSum));
    }

    public double sigmoidDerivative(double output) {
        return output / (1 - output);
    }
}

class Synapse implements Serializable {

    private Neuron sourceNeuron; // neuron the synapse originates from
    private double weight; // weight of the synapse

    public Synapse(Neuron sourceNeuron) {
        this.sourceNeuron = sourceNeuron;
        this.weight = Math.random() - 0.5;
    }

    public Neuron getSourceNeuron() {
        return sourceNeuron;
    }

    public double getWeight() {
        return weight;
    }

    public void adjustWeight(double deltaWeight) {
        this.weight += deltaWeight;
    }
}

class NeuralNetwork implements Serializable {
    Neuron[] input;
    Neuron[] hidden;
    Neuron output;
    double learningRate = 0.1;
    public NeuralNetwork() {
        input = new Neuron[2];
        hidden = new Neuron[2];
        output = new Neuron();
        for(int i=0; i<2; i++) {
            input[i] = new Neuron();
        }
        for(int i=0; i<2; i++) {
            hidden[i] = new Neuron();
        }
        for(int i=0; i<2; i++) {
            Synapse s = new Synapse(hidden[i]);
            output.addInput(s);
        }
        for(int i=0; i<2; i++) {
            for(int j=0; j<2; j++) {
                Synapse s = new Synapse(input[j]);
                hidden[i].addInput(s);
            }
        }
    }
    public void setInput(double[] inputVal) {
        for(int i=0; i<2; i++) {
            input[i].setOutput(inputVal[i]);
        }
    }
    public double feedForward(double[] inputVal) {
        setInput(inputVal);
        for(int i=0; i<2; i++) {
            hidden[i].activate();
        }
        output.activate();
        return output.getOutput();
    }
    public void train(double[][] trainingInputs, double[] targetOutputs) {
        for(int i=0; i<4; i++) {
            double[] inputs = trainingInputs[i];
            double target = targetOutputs[i];
            double currentOutput = feedForward(inputs);
            double delta = 0;
            double neuronError = 0;
            for(int j=0; j<2; j++) {
                Synapse s = output.getInputs().get(j);
                neuronError = output.getDerivative() * (target - currentOutput);
                delta = learningRate * s.getSourceNeuron().getOutput() * neuronError;
                output.setError(neuronError);
                s.adjustWeight(delta);
            }
            for(int j=0; j<2; j++) {
                for(int k=0; k<2; k++) {
                    Synapse s = hidden[j].getInputs().get(k);
                    Synapse s1 = output.getInputs().get(j);
                    delta = learningRate * s.getSourceNeuron().getOutput() * hidden[j].getDerivative() * s1.getWeight() * output.getError();
                    s.adjustWeight(delta);
                }
            }
        }
    }
}

I found the backpropagation algorithm in somebody else's implementation on GitHub and tried to use it, but I either get outputs around 0.50 or just NaN. I can't tell whether I'm using the wrong algorithm, whether I've implemented it incorrectly, or whether it's something else.

The algorithm I'm using goes like this. First I find the error of the neuron itself:

if it's the output neuron then neuronError = (derivative of the output neuron) * (expected output - actual output)

if it's a hidden neuron then neuronError = (derivative of the hidden neuron) * (neuronError of the output neuron) * (weight of the synapse that goes from the hidden neuron to the output neuron)

then deltaWeight = learningRate * (neuronError of the neuron the synapse leads into) * (output of the neuron the synapse starts from)

At last I add deltaWeight to the previous weight.
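Written out as plain code, with throwaway numbers just to show the formulas I mean, the two updates look like this:

```java
// Throwaway illustration of the error/delta formulas described above.
// All the numeric values here are made up for the example.
public class BackpropFormulas {
    public static void main(String[] args) {
        double learningRate = 0.1;

        // Made-up values for one training example
        double target = 1.0;
        double outOutput = 0.6;       // output neuron's output
        double outDerivative = outOutput * (1 - outOutput); // sigmoid derivative in terms of the output
        double hiddenOutput = 0.55;   // a hidden neuron's output
        double hiddenDerivative = hiddenOutput * (1 - hiddenOutput);
        double weightHiddenToOut = 0.4; // weight of the hidden -> output synapse
        double inputValue = 1.0;      // output of an input neuron

        // Output neuron: error, and weight update for a hidden -> output synapse
        double outputError = outDerivative * (target - outOutput);
        double deltaOut = learningRate * outputError * hiddenOutput;

        // Hidden neuron: error, and weight update for an input -> hidden synapse
        double hiddenError = hiddenDerivative * outputError * weightHiddenToOut;
        double deltaHidden = learningRate * hiddenError * inputValue;

        System.out.println(deltaOut + " " + deltaHidden);
    }
}
```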

Sorry for the long post. If you don't read through the code, can you at least tell me whether my algorithm is correct? Thank you.

Your sigmoid derivative is wrong; it should be:

public double sigmoidDerivative(double output) {
    return output * (1 - output);
}
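You can sanity-check this identity numerically: for s = sigmoid(x), the derivative ds/dx equals s * (1 - s). A quick standalone check (the class name is mine):

```java
// Numerical check that d/dx sigmoid(x) == sigmoid(x) * (1 - sigmoid(x)).
public class SigmoidDerivativeCheck {
    static double sigmoid(double x) { return 1 / (1 + Math.exp(-x)); }

    public static void main(String[] args) {
        double x = 0.3;
        double s = sigmoid(x);
        double analytic = s * (1 - s);
        // Central finite-difference approximation of the derivative
        double h = 1e-6;
        double numeric = (sigmoid(x + h) - sigmoid(x - h)) / (2 * h);
        // The two values agree to many decimal places
        System.out.println("analytic = " + analytic + ", numeric = " + numeric);
    }
}
```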

As I said in my comment, you have {1, 1} twice in your training input, so replace one of them with {0, 0}.

Finally, increase the number of iterations from 40 to 100,000.
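For reference, the corrected training set covers all four XOR input combinations exactly once; this quick standalone check (class name is mine) confirms each target matches the XOR of its inputs:

```java
// Corrected XOR training set: all four input combinations, each once.
public class XorData {
    public static void main(String[] args) {
        double[][] trainingInputs = {
                {1, 1},
                {1, 0},
                {0, 1},
                {0, 0}   // was a duplicate {1, 1} before
        };
        double[] targetOutputs = {0, 1, 1, 0};
        for (int i = 0; i < 4; i++) {
            double xor = (trainingInputs[i][0] != trainingInputs[i][1]) ? 1 : 0;
            System.out.println(java.util.Arrays.toString(trainingInputs[i])
                    + " -> " + targetOutputs[i] + " (xor = " + xor + ")");
        }
    }
}
```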
