
How to find two parameters with the gradient descent method in Python?

I have a few lines of code that don't converge. If anyone has an idea why, I would greatly appreciate it. The original equation is written in def f(x, y, b, m), and I need to find the parameters b and m.

  import numpy as np

  np.random.seed(42)
  x = np.random.normal(0, 5, 100)
  y = 50 + 2 * x + np.random.normal(0, 2, len(x))

  def f(x, y, b, m):
      return (1/len(x))*np.sum((y - (b + m*x))**2) # it is supposed to be a sum operator

  def dfb(x, y, b, m): # partial derivative with respect to b
      return b - m*np.mean(x)+np.mean(y)

  def dfm(x, y, b, m): # partial derivative with respect to m
      return np.sum(x*y - b*x - m*x**2)

  b0 = np.mean(y)
  m0 = 0
  alpha = 0.0001
  beta = 0.0001
  epsilon = 0.01

  while True:

      b = b0 - alpha * dfb(x, y, b0, m0)
      m = m0 - alpha * dfm(x, y, b0, m0)

      if np.sum(np.abs(m-m0)) <= epsilon and np.sum(np.abs(b-b0)) <= epsilon:
          break
      else:
          m0 = m
          b0 = b
      print(m, f(x, y, b, m))

Both derivatives have some signs mixed up:

def dfb(x, y, b, m): # partial derivative with respect to b
  # return b - m*np.mean(x)+np.mean(y)
  #          ^-------------^------ these are incorrect
  return b + m*np.mean(x) - np.mean(y)

def dfm(x, y, b, m): # partial derivative with respect to m
  #      v------ this should be negative
  return -np.sum(x*y - b*x - m*x**2)
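A quick way to see which signs are wrong is to compare the hand-written derivatives with a central finite-difference estimate of the gradient of f. This check is not part of the original answer: numeric_grad, the *_question and *_signfix function names, and the step h = 1e-5 are hypothetical choices for illustration. It reuses the question's data and f; because f is quadratic in b and m, the central difference equals the true gradient up to rounding error.

```python
import numpy as np

# Same synthetic data as in the question.
np.random.seed(42)
x = np.random.normal(0, 5, 100)
y = 50 + 2 * x + np.random.normal(0, 2, len(x))

def f(x, y, b, m):
    # The question's objective: mean squared error of the line b + m*x.
    return (1 / len(x)) * np.sum((y - (b + m * x)) ** 2)

def numeric_grad(x, y, b, m, h=1e-5):
    # Hypothetical helper: central differences (f(p + h) - f(p - h)) / (2h).
    gb = (f(x, y, b + h, m) - f(x, y, b - h, m)) / (2 * h)
    gm = (f(x, y, b, m + h) - f(x, y, b, m - h)) / (2 * h)
    return gb, gm

# The question's derivatives (wrong signs).
def dfb_question(x, y, b, m):
    return b - m * np.mean(x) + np.mean(y)

def dfm_question(x, y, b, m):
    return np.sum(x * y - b * x - m * x**2)

# Sign-corrected versions from this answer (constant factors still missing).
def dfb_signfix(x, y, b, m):
    return b + m * np.mean(x) - np.mean(y)

def dfm_signfix(x, y, b, m):
    return -np.sum(x * y - b * x - m * x**2)

b, m = 0.0, 0.0  # arbitrary test point
gb, gm = numeric_grad(x, y, b, m)
print("numeric    :", gb, gm)
print("question   :", dfb_question(x, y, b, m), dfm_question(x, y, b, m))
print("sign-fixed :", dfb_signfix(x, y, b, m), dfm_signfix(x, y, b, m))
```

At (0, 0), both of the question's derivatives have the opposite sign to the numeric gradient, while the sign-fixed versions have the same sign and are proportional to it.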

In fact, these derivatives are still missing some constants:

  • dfb should be multiplied by 2
  • dfm should be multiplied by 2/len(x)

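On its own, that's not too bad, because a constant factor on the gradient is effectively absorbed into alpha. Here, however, the two factors differ (1/2 for b versus len(x)/2 for m), so m effectively gets a step len(x) times larger than b, which can make convergence slower.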
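For reference, differentiating f with the chain rule shows where the factors 2 and 2/len(x) come from. Here n = len(x), \bar{x} = np.mean(x) and \bar{y} = np.mean(y):

```latex
\begin{aligned}
f(b, m) &= \frac{1}{n} \sum_{i=1}^{n} \left(y_i - b - m x_i\right)^2 \\
\frac{\partial f}{\partial b} &= -\frac{2}{n} \sum_{i=1}^{n} \left(y_i - b - m x_i\right)
  = 2\left(b + m\bar{x} - \bar{y}\right) \\
\frac{\partial f}{\partial m} &= -\frac{2}{n} \sum_{i=1}^{n} x_i \left(y_i - b - m x_i\right)
  = -\frac{2}{n} \sum_{i=1}^{n} \left(x_i y_i - b x_i - m x_i^2\right)
\end{aligned}
```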

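If you use the correct derivatives, the loop stops after a single iteration. Note that this is the stopping test firing rather than real convergence: at the starting point (b0 = np.mean(y), m0 = 0) dfb is exactly 0, and with alpha = 0.0001 the first step in m is already smaller than epsilon = 0.01, so m is still close to 0 instead of near the true slope of 2. A larger alpha (for example 0.01) together with a much smaller epsilon lets the loop actually reach the minimum. The correct derivatives are: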

def dfb(x, y, b, m): # partial derivative with respect to b
  return 2 * (b + m * np.mean(x) - np.mean(y))

def dfm(x, y, b, m): # partial derivative with respect to m
  # Used `mean` here since (2/len(x)) * np.sum(...)
  # is the same as 2 * np.mean(...)
  return -2 * np.mean(x * y - b * x - m * x**2)
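Putting the pieces together, here is a minimal end-to-end sketch (not the original poster's code) that plugs the exact derivatives into a fixed-step gradient descent loop. The gradient_descent helper, alpha = 0.01, epsilon = 1e-8 and the max_iter safety limit are assumptions chosen for illustration. The result is cross-checked against np.polyfit, which fits the same least-squares line directly.

```python
import numpy as np

# Same synthetic data as in the question.
np.random.seed(42)
x = np.random.normal(0, 5, 100)
y = 50 + 2 * x + np.random.normal(0, 2, len(x))

def f(x, y, b, m):
    # Mean squared error (the question's objective).
    return np.mean((y - (b + m * x)) ** 2)

def dfb(x, y, b, m):
    # Exact partial derivative of f with respect to b.
    return 2 * (b + m * np.mean(x) - np.mean(y))

def dfm(x, y, b, m):
    # Exact partial derivative of f with respect to m.
    return -2 * np.mean(x * y - b * x - m * x**2)

def gradient_descent(x, y, b0, m0, alpha=0.01, epsilon=1e-8, max_iter=100_000):
    # Hypothetical helper: plain gradient descent with a fixed step size alpha.
    # Stops when both parameter updates are at most epsilon,
    # or after max_iter steps as a safety net.
    b, m = b0, m0
    for i in range(1, max_iter + 1):
        b_new = b - alpha * dfb(x, y, b, m)
        m_new = m - alpha * dfm(x, y, b, m)
        converged = abs(b_new - b) <= epsilon and abs(m_new - m) <= epsilon
        b, m = b_new, m_new
        if converged:
            return b, m, i
    return b, m, max_iter

b, m, n_iter = gradient_descent(x, y, b0=np.mean(y), m0=0.0)
print(f"gradient descent: b = {b:.4f}, m = {m:.4f} "
      f"after {n_iter} iterations, MSE = {f(x, y, b, m):.4f}")

# Cross-check: np.polyfit returns [slope, intercept] for degree 1.
m_ls, b_ls = np.polyfit(x, y, 1)
print(f"np.polyfit:       b = {b_ls:.4f}, m = {m_ls:.4f}")

# The question's settings: the stopping test fires after one tiny step.
b_q, m_q, n_q = gradient_descent(x, y, b0=np.mean(y), m0=0.0,
                                 alpha=0.0001, epsilon=0.01)
print(f"alpha=0.0001, epsilon=0.01: m = {m_q:.4f} after {n_q} iteration(s)")
```

The last call reproduces the question's settings and stops after one iteration with m still near 0, so a small per-step change alone is weak evidence of convergence. Also keep alpha modest: f curves much more steeply in m than in b (its second derivatives are 2 * np.mean(x**2) versus 2), so too large a step makes the m update oscillate and diverge.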

