Gradient Descent algorithm taking long time to complete - Efficiency - Python

I am trying to implement the gradient descent algorithm in Python. Here is my code:

import time

import numpy as np

def grad_des(xvalues, yvalues, R=0.01, epsilon=0.0001, MaxIterations=1000):
    xvalues= np.array(xvalues)
    yvalues = np.array(yvalues)
    length = len(xvalues)
    alpha = 1
    beta = 1
    converged = False
    i=0
    cost = sum([(alpha + beta*xvalues[i] - yvalues[i])**2 for i in range(length)]) / (2 * length)
    start_time = time.time()
    while not converged:      
        alpha_deriv = sum([(alpha + beta*xvalues[i] - yvalues[i]) for i in range(length)]) / (length)
        beta_deriv =  sum([(alpha + beta*xvalues[i] - yvalues[i])*xvalues[i] for i in range(length)]) / (length)
        alpha = alpha - R * alpha_deriv
        beta = beta - R * beta_deriv
        new_cost = sum( [ (alpha + beta*xvalues[i] - yvalues[i])**2 for i in range(length)] )  / (2*length)
        if abs(cost - new_cost) <= epsilon:
            print 'Converged'
            print 'Number of Iterations:', i
            converged = True
        cost = new_cost
        i = i + 1      
        if i == MaxIterations:
            print 'Maximum Iterations Exceeded'
            converged = True
    print "Time taken: " + str(round(time.time() - start_time,2)) + " seconds"
    return alpha, beta

The code works, but it takes more than 25 seconds for roughly 600 iterations. I tried converting the inputs to arrays before doing the calculations, which reduced the time from 300 seconds to 25 seconds, but I feel it can be reduced further. Can anybody help me improve this algorithm?

Thanks

As I mentioned in a comment, I can't reproduce the slowness; however, here are some potential issues:

  1. length never changes, but you are repeatedly calling range(length). In Python 2.x, range builds a full list, and doing that several times per iteration slows things down (object creation is not cheap). Use xrange (or import a Python 3-style lazy range from six or future) and create the range once up front rather than on every iteration.

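  2. i is being reused in a way that causes problems. You're trying to use it as the overall iteration count, but in Python 2 a list comprehension's loop variable leaks into the enclosing function scope, so every comprehension that uses i overwrites it. As a result, the printed "iteration" count will always end up as length - 1, and because i is reset to length after the increment, the i == MaxIterations check can only trigger if length happens to equal MaxIterations.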
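Both points can be addressed even without NumPy. The following is a minimal sketch (Python 3; the name grad_des_loop is hypothetical, not from the question) that iterates over (x, y) pairs with zip instead of indexing through range(length), so there is no index variable to leak, and keeps the iteration counter in its own name. Unlike the original, it returns the iteration count instead of printing it.

```python
def grad_des_loop(xvalues, yvalues, R=0.01, epsilon=0.0001, max_iterations=1000):
    """Pure-Python batch gradient descent for y ~ alpha + beta * x (hypothetical rewrite).

    Returns (alpha, beta, iterations_used).
    """
    # Build the (x, y) pairs once; every sum below reuses this list.
    pairs = list(zip(xvalues, yvalues))
    n = len(pairs)

    def cost(alpha, beta):
        # Half mean squared error, as in the original code.
        return sum((alpha + beta * x - y) ** 2 for x, y in pairs) / (2 * n)

    alpha, beta = 1.0, 1.0
    current_cost = cost(alpha, beta)
    # A dedicated counter: no comprehension can overwrite `iteration`.
    for iteration in range(max_iterations):
        alpha_deriv = sum(alpha + beta * x - y for x, y in pairs) / n
        beta_deriv = sum((alpha + beta * x - y) * x for x, y in pairs) / n
        alpha -= R * alpha_deriv
        beta -= R * beta_deriv
        new_cost = cost(alpha, beta)
        if abs(current_cost - new_cost) <= epsilon:
            return alpha, beta, iteration  # converged
        current_cost = new_cost
    return alpha, beta, max_iterations  # hit the iteration cap
```

Because the cap is enforced by the for loop itself, it always takes effect, regardless of the input length.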

The lowest-hanging fruit I can see is vectorization. You have a lot of list comprehensions; they're faster than explicit for loops, but they are nothing compared to proper use of NumPy arrays, where the per-element loop runs in compiled code instead of the Python interpreter.
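To see the difference in isolation, here is a small sketch that computes the same cost once with the question's list-comprehension style and once as a whole-array NumPy expression, then times both with timeit. The data are synthetic (random 1024-element arrays, an assumption chosen to match the size mentioned below), and the timings depend on the machine, so no numbers are given here.

```python
import timeit

import numpy as np

# Synthetic data, assumed for illustration only.
rng = np.random.default_rng(0)
xvalues = rng.random(1024)
yvalues = 0.3 + 0.5 * xvalues + rng.normal(scale=0.1, size=1024)
alpha, beta = 1.0, 1.0
length = len(xvalues)


def cost_comprehension():
    # One trip through the interpreter (plus temporary objects) per element.
    return sum([(alpha + beta * xvalues[i] - yvalues[i]) ** 2 for i in range(length)]) / (2 * length)


def cost_vectorized():
    # Whole-array arithmetic: the element loop runs inside NumPy.
    return np.sum((alpha + beta * xvalues - yvalues) ** 2) / (2 * length)


print("comprehension:", timeit.timeit(cost_comprehension, number=100))
print("vectorized:   ", timeit.timeit(cost_vectorized, number=100))
```

Both functions return the same value up to floating-point rounding; only the place where the loop runs differs.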

import time

import numpy as np

def grad_des_vec(xvalues, yvalues, R=0.01, epsilon=0.0001, MaxIterations=1000):
    xvalues = np.array(xvalues)
    yvalues = np.array(yvalues)
    length = len(xvalues)
    alpha = 1
    beta = 1
    converged = False
    i = 0
    cost = np.sum((alpha + beta * xvalues - yvalues)**2) / (2 * length)
    start_time = time.time()
    while not converged:
        alpha_deriv = np.sum(alpha + beta * xvalues - yvalues) / length
        beta_deriv = np.sum(
            (alpha + beta * xvalues - yvalues) * xvalues) / length
        alpha = alpha - R * alpha_deriv
        beta = beta - R * beta_deriv
        new_cost = np.sum((alpha + beta * xvalues - yvalues)**2) / (2 * length)
        if abs(cost - new_cost) <= epsilon:
            print('Converged')
            print('Number of Iterations:', i)
            converged = True
        cost = new_cost
        i = i + 1
        if i == MaxIterations:
            print('Maximum Iterations Exceeded')
            converged = True
    print("Time taken: " + str(round(time.time() - start_time, 2)) + " seconds")
    return alpha, beta
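One further refinement, which is not part of the original answer: each iteration of grad_des_vec evaluates alpha + beta * xvalues - yvalues three times. The residual computed for the new cost is exactly the one the next iteration needs for its gradients, so it can be computed once per iteration and reused. A minimal sketch (the name grad_des_fast is hypothetical; it returns the iteration count instead of printing it and leaves timing to the caller):

```python
import numpy as np


def grad_des_fast(xvalues, yvalues, R=0.01, epsilon=0.0001, max_iterations=1000):
    """Same update rule and stopping test as grad_des_vec, one residual per iteration.

    Returns (alpha, beta, iterations_used).
    """
    x = np.asarray(xvalues, dtype=float)
    y = np.asarray(yvalues, dtype=float)
    n = len(x)
    alpha, beta = 1.0, 1.0
    # Residual at the starting point; it also feeds the first gradient step.
    residual = alpha + beta * x - y
    cost = residual.dot(residual) / (2 * n)
    for iteration in range(max_iterations):
        # Gradients of the cost with respect to alpha and beta.
        alpha_deriv = residual.sum() / n
        beta_deriv = residual.dot(x) / n
        alpha -= R * alpha_deriv
        beta -= R * beta_deriv
        # The only residual evaluation in the loop: used for the new cost now
        # and for the gradients on the next pass.
        residual = alpha + beta * x - y
        new_cost = residual.dot(residual) / (2 * n)
        if abs(cost - new_cost) <= epsilon:
            return alpha, beta, iteration
        cost = new_cost
    return alpha, beta, max_iterations
```

The iteration sequence is the same as in grad_des_vec, so results should agree up to floating-point rounding; the saving is roughly one third of the array passes per iteration.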

For comparison:

In [47]: grad_des(xval, yval)
Converged
Number of Iterations: 198
Time taken: 0.66 seconds
Out[47]: 
(0.28264882215511067, 0.53289263416071131)

In [48]: grad_des_vec(xval, yval)
Converged
Number of Iterations: 198
Time taken: 0.03 seconds
Out[48]: 
(0.28264882215511078, 0.5328926341607112)

That's about a factor of 20 speedup (xval and yval were both 1024-element arrays).
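As an optional sanity check that goes beyond the original answer: the model alpha + beta * x also has a closed-form least-squares solution, so the (alpha, beta) returned by either function can be compared against it. A minimal sketch using np.polyfit, which for degree 1 returns the coefficients highest power first, i.e. [slope, intercept]; the helper name least_squares_line is hypothetical.

```python
import numpy as np


def least_squares_line(xvalues, yvalues):
    """Closed-form least-squares fit of y ~ alpha + beta * x, returned as (alpha, beta)."""
    x = np.asarray(xvalues, dtype=float)
    y = np.asarray(yvalues, dtype=float)
    # polyfit with deg=1 returns [slope, intercept].
    beta, alpha = np.polyfit(x, y, 1)
    return alpha, beta


# Noise-free example: points on y = 0.3 + 0.5 * x.
x = np.linspace(0.0, 1.0, 1024)
alpha, beta = least_squares_line(x, 0.3 + 0.5 * x)
print(alpha, beta)
```

Because the gradient descent loop stops as soon as the cost changes by at most epsilon, expect its result to be close to, but not identical with, the closed-form values; a smaller epsilon narrows the gap.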
