
Levenberg-Marquardt algorithm with GPU support

For shallow neural nets, the Levenberg-Marquardt (LM) algorithm does amazingly well.

However, only MATLAB and pyrenn (a Python package) seem to have a robust implementation of it, and neither of them supports training on the GPU. I also tried neupy (another Python package), but it is not robust and fails when you train for many epochs or on a large dataset. Do you know of a good LM Python package for neural networks that can be trained on a GPU?

I'm not really sure such an implementation would be useful, because for anything except the most trivial networks, computing the full Hessian is very impractical. That said, implementing Levenberg-Marquardt in PyTorch for shallow networks isn't really that hard, and even if your implementation is not optimal you'll still get some speedup compared to the CPU version. Below is a very quick, and thus very imperfect/suboptimal, example that runs on the GPU. The speedup is about 3x for inputs of size 3x20x20 on a 1080 Ti (I don't have free GPUs at the moment to test larger inputs). If you have prior knowledge about the loss function (e.g. it is a least-squares loss, so the Hessian can be approximated as 2*Jacobian^t*Jacobian), then some variant of the code below might be useful; a Gauss-Newton sketch of that approximation follows the full example.

import torch
import matplotlib.pyplot as plt

def LM(model, loss, toy_input, n_iter=30):

    alpha = 1e-3          # damping factor, adapted after every iteration
    loss_hist = []
    for i in range(n_iter):
        model.train()
        out = model(toy_input).unsqueeze(1)
        loss_out = loss(out)
        prev_loss = loss_out.item()
        gradients = torch.autograd.grad(loss_out, model.parameters(), create_graph=True)

        model.eval()
        Hessian, g_vector = eval_hessian(gradients, model)

        # LM step: dx = -(alpha*I + H)^-1 * g
        dx = -(alpha * torch.eye(Hessian.shape[-1]).cuda() + Hessian).inverse().mm(g_vector).detach()

        cnt = 0
        model.zero_grad()

        # apply the step to each parameter, slicing dx at matching offsets
        for p in model.parameters():
            num = p.numel()
            p.requires_grad = False
            p += dx[cnt:cnt + num, :].reshape(p.shape)
            cnt += num
            p.requires_grad = True


        # re-evaluate the loss with the updated parameters
        out = model(toy_input).unsqueeze(1)
        loss_out = loss(out)

        if loss_out < prev_loss:
            print("Successful iteration")
            loss_hist.append(loss_out.item())
            alpha /= 10       # step accepted: reduce the damping
        else:
            print("Augmenting step size")
            alpha *= 10       # step rejected: increase the damping and undo the step
            cnt = 0
            for p in model.parameters():
                num = p.numel()
                p.requires_grad = False
                p -= dx[cnt:cnt + num, :].reshape(p.shape)
                cnt += num
                p.requires_grad = True

    return loss_hist



def eval_hessian(loss_grad, model):
    # flatten the first-order gradients into a single vector
    g_vector = torch.cat([g.contiguous().view(-1) for g in loss_grad])
    l = g_vector.size(0)
    hessian = torch.zeros(l, l).cuda()
    # differentiate each gradient entry again to fill the Hessian row by row
    for idx in range(l):
        grad2rd = torch.autograd.grad(g_vector[idx], model.parameters(), create_graph=True)
        hessian[idx] = torch.cat([g.contiguous().view(-1) for g in grad2rd])
    return hessian, g_vector.unsqueeze(1)
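
# Note: newer PyTorch versions also ship torch.autograd.functional.hessian,
# which computes the same matrix for a scalar function of a flat parameter
# vector; flat_loss below is a hypothetical wrapper that rebuilds the model
# parameters from that vector and returns the scalar loss:
#     from torch.autograd.functional import hessian
#     H = hessian(flat_loss, flat_params)
# The explicit loop above is kept because it works directly on model.parameters().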

def toy_loss(vec):
    # simple sum-of-squares loss: vec^T * vec
    return vec.transpose(0, 1).mm(vec)

class toy_model(torch.nn.Module):

    def __init__(self, in_c, width, height):
        super().__init__()
        self.cnv = torch.nn.Conv2d(in_c, 1, 3, 1, padding=1)
        self.lin = torch.nn.Linear(1 * width * height, 16)

    def forward(self, tns):
        out = self.cnv(tns)
        out = self.lin(out.view(-1))    # flatten; assumes batch size 1
        return out

if __name__ == "__main__":

    H = 20
    W = 20
    toy_input = torch.rand(1, 3, H, W).cuda()
    toy_mdl = toy_model(3, W, H)
    toy_mdl.cuda()

    loss_hist = LM(toy_mdl, toy_loss, toy_input)

    # plot the losses of the accepted iterations
    plt.plot(loss_hist)
    plt.show()

Note that I took the code for eval_hessian from here.
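
For completeness, here is a minimal sketch of the least-squares variant mentioned above, where the Hessian is approximated as 2*Jacobian^t*Jacobian (Gauss-Newton) so the expensive second-order loop is avoided entirely. The names residual_fn and flat_params are assumptions, not part of the code above: residual_fn is taken to map a flat parameter vector to the residual vector r, with loss = r^T r (it could be built with torch.nn.utils.parameters_to_vector and vector_to_parameters).

import torch

def lm_step_gauss_newton(residual_fn, flat_params, alpha):
    # Jacobian of the residuals w.r.t. the parameters, shape (n_res, n_params)
    J = torch.autograd.functional.jacobian(residual_fn, flat_params)
    r = residual_fn(flat_params)
    H = 2.0 * J.t().mm(J)       # Gauss-Newton approximation of the Hessian
    g = 2.0 * J.t().mv(r)       # gradient of r^T r
    damped = H + alpha * torch.eye(H.shape[0], device=flat_params.device)
    # solve (H + alpha*I) dx = -g rather than forming an explicit inverse
    dx = torch.linalg.solve(damped, -g)
    return flat_params + dx

Solving the damped linear system is cheaper and numerically safer than the explicit .inverse() call used above, and the same accept/reject logic for alpha carries over unchanged.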
