
Neural Network convergence speed (Levenberg-Marquardt) (MATLAB)

I was trying to approximate a function (single input and single output) with an ANN. Using the MATLAB toolbox I could see that with 5 or more neurons in the hidden layer I can achieve a very nice result, so I am trying to do the same thing manually.

Calculations: Since the network has only one input and one output, the partial derivative of the error (e = d - o, where 'd' is the desired output and 'o' is the actual output) with respect to a weight connecting a hidden neuron j to the output neuron is -hj (where hj is the output of hidden neuron j); the partial derivative of the error with respect to the output bias is -1; the partial derivative of the error with respect to a weight connecting the input to a hidden neuron j is -woj*f'*i, where woj is the output weight of hidden neuron j, f' is the tanh() derivative evaluated at that neuron's net input, and 'i' is the input value; finally, the partial derivative of the error with respect to a hidden-layer bias is the same as the previous one, except that the input factor is absent: -woj*f'.
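
Collected in one place (writing net_hj = w_hj*i + b_hj for hidden neuron j's net input, where w_hj is its input weight and b_hj its bias):

    \partial e / \partial w_{oj} = -h_j
    \partial e / \partial b_{o}  = -1
    \partial e / \partial w_{hj} = -w_{oj} \, f'(net_{hj}) \, i
    \partial e / \partial b_{hj} = -w_{oj} \, f'(net_{hj})

with f'(net) = 1 - \tanh^2(net) and h_j = \tanh(net_{hj}); these are exactly the entries that fill each row of the Jacobian in the code below.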

The problem is: the MATLAB algorithm always converges faster and better. I can achieve the same curve as MATLAB does, but my algorithm requires many more epochs. I've tried removing the pre- and post-processing functions from the MATLAB network; it still converges faster. I've also tried to create and configure the network and extract the weight/bias values before training, so I could copy them into my algorithm and check whether it then converges just as fast, but nothing changed (does the weight/bias initialization happen inside the create/configure functions or inside train?).
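
For reference, a minimal sketch of how the toolbox's initial weights can be read out before training (assuming the Neural Network Toolbox functions fitnet and configure, and that t and yp are the column vectors used in the loop below; configure already initializes the weights and biases, and train then starts from whatever values the network currently holds):

net = fitnet(hidden_layer_len);        %tanh hidden layer + purelin output, trainlm by default
net.divideFcn = 'dividetrain';         %no validation/test split, like the manual loop
net.inputs{1}.processFcns  = {};       %drop mapminmax/removeconstantrows to match the manual code
net.outputs{2}.processFcns = {};
net = configure(net, t', yp');         %configure sets the sizes AND initializes weights/biases
IW = net.IW{1,1};                      %input -> hidden weights (one per hidden neuron)
LW = net.LW{2,1};                      %hidden -> output weights
b1 = net.b{1};                         %hidden biases
b2 = net.b{2};                         %output bias
%copy [IW; LW'; b1; b2] into w, run the manual loop, then compare against train(net, t', yp')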

Does the MATLAB algorithm have some kind of optimization inside the code? Or does the difference come only from the organization of the training set and the weight/bias initialization?

In case anyone wants to look at my code, here is the main loop that does the training:

Err2 = N;
epochs = 0;
%compare MSE of error2
while ((Err2/N > 0.0003) && (u < 10000000) && (epochs < 100))
    epochs = epochs+1;
    Err = 0;
    %input->hidden weight vector
    wh = w(1:hidden_layer_len);
    %hidden->output weight vector
    wo = w((hidden_layer_len+1):(2*hidden_layer_len));
    %hidden bias
    bi = w((2*hidden_layer_len+1):(3*hidden_layer_len));
    %output bias
    bo = w(length(w));
    %start forward propagation
    for i=1:N
        %take next input value
        x = t(i);
        %propagate to hidden layer
        neth = x*wh + bi;
        %propagate through neurons
        ij = tanh(neth)';
        %propagate to output layer
        neto = ij*wo + bo;
        %propagate to output (purelin)
        output(i) = neto;
        %calculate difference from target (error)
        error(i) = yp(i) - output(i);

        %Backpropagation:

        %tanh derivative
        fhd = 1 - tanh(neth').*tanh(neth');
        %Jacobian row for sample i, columns ordered like w:
        %[de/dwh  de/dwo  de/dbi  de/dbo]
        J(i,:) = [-x*wo'.*fhd -ij -wo'.*fhd -1];

        %SSE (sum square error)
        Err = Err + 0.5*error(i)*error(i);
    end

    %calculate next error with updated weights and compare with old error

    %start error2 from error1 + 1 to enter while loop
    Err2 = Err+1;
    %while error2 is > than old error and Mu (u) is not too large
    while ((Err2 > Err) && (u < 10000000))
        %Weight update
        w2 = w - (J'*J + u*eye(3*hidden_layer_len+1)) \ (J'*error');   %solve the damped normal equations instead of forming the inverse explicitly
        %New Error calculation

        %New weights to propagate
        wh = w2(1:hidden_layer_len);
        wo = w2((hidden_layer_len+1):(2*hidden_layer_len));
        %new bias to propagate
        bi = w2((2*hidden_layer_len+1):(3*hidden_layer_len));
        bo = w2(end);
        %calculate error2
        Err2 = 0;
        for i=1:N
            %forward propagation again
            x = t(i);
            neth = x*wh + bi;
            ij = tanh(neth)';
            neto = ij*wo + bo;
            output(i) = neto;
            error2(i) = yp(i) - output(i);

            %Error2 (SSE)
            Err2 = Err2 + 0.5*error2(i)*error2(i);
        end

        %compare MSE from error2 with a minimum
        %if greater, keep running
        if (Err2/N > 0.0003)
            %compare with old error
            if (Err2 <= Err)
                %if less, update weights and decrease Mu (u)
                w = w2;
                u = u/10;
            else
                %if greater, increment Mu (u)
                u = u*10;
            end
        end
    end
end

It's not easy to know the exact implementation of the Levenberg-Marquardt algorithm in MATLAB. You may try to run the algorithm one iteration at a time and check whether it is identical to your algorithm. You can also try other implementations, such as http://www.mathworks.com/matlabcentral/fileexchange/16063-lmfsolve-m--levenberg-marquardt-fletcher-algorithm-for-nonlinear-least-squares-problems, to see if the performance can be improved. For simple learning problems, convergence speed may be a matter of learning rate; you might simply increase the learning rate to get faster convergence.
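
To follow the one-iteration-at-a-time suggestion with the toolbox, a minimal sketch (assuming fitnet/train from the Neural Network Toolbox and the t/yp vectors from the question, transposed to row vectors; train continues from the network's current weights rather than re-initializing them):

net = fitnet(hidden_layer_len, 'trainlm');
net.divideFcn = 'dividetrain';          %train on all samples, no validation/test split
net.trainParam.epochs = 1;              %one Levenberg-Marquardt iteration per call to train
net.trainParam.showWindow = false;      %no training GUI while looping
for k = 1:20
    [net, tr] = train(net, t', yp');    %continues from the current weights
    fprintf('iteration %d: mse = %g\n', k, tr.perf(end));   %tr also records trainlm's mu per epoch
end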
