繁体   English   中英

R中的梯度下降的图向量

[英]Plot vectors of gradient descent in R

我已经在R中编写了梯度下降算法的代码,现在我正试图“绘制”矢量的路径。

我的轮廓图中有画点,但这是不正确的,因为没人知道首先发生了什么。

在我的算法中,我总是有一个先前的状态P =(Xi,Yi)和一个后来的状态L =(Xi + 1,Yi + 1),那么,如何在contourpersp图中绘制矢量PL?

我只用contour得到它,其中红点是收敛:

轮廓

对于persp

透视

谢谢大家!

编辑:

图形可以分别变废为宝:

f<-function(u,v){
  u*u*exp(2*v)+4*v*v*exp(-2*u)-4*u*v*exp(v-u)
}

x = seq(-2, 2, by = 0.5)
y = seq(-2, 2, by = 0.5)
z <- outer(x,y,f)
#Contour plot
contour(x,y,z)
#Persp plot
persp(x, y, z, phi = 25, theta = 55, xlim=c(-2,2), ylim=c(-2,2),
      xlab = "U", ylab = "V",
      main = "F(u,v)", col="yellow", ticktype = "detailed"
) -> res

有一个技巧使用persp绘制点,如?persp 通过使用trans3d ,您可以成功地将点和线放在透视图上。

f<-function(u,v){
  u*u*exp(2*v)+4*v*v*exp(-2*u)-4*u*v*exp(v-u)
}

x = seq(-2, 2, by = 0.5)
y = seq(-2, 2, by = 0.5)
z <- scale(outer(x,y,f))

view <- persp(x, y, z, phi = 30, theta = 30, xlim=c(-2,2), ylim=c(-2,2),
      xlab = "X", ylab = "Y", zlab = "Z", scale = FALSE,
      main = "F(u,v)", col="yellow", ticktype = "detailed")

set.seed(2)
pts <- data.frame(x = sample(x, 3),
                  y = sample(y, 3),
                  z = sample(z, 3))

points(trans3d(x = pts$x, y = pts$y, z = pts$z, pmat = view), pch = 16)
lines(trans3d(x = pts$x, y = pts$y, z = pts$z, pmat = view))

在此处输入图片说明

Himmelblau函数作为测试示例:

f <- function(x, y) { (x^2+y-11)^2 + (x+y^2-7)^2 }

其偏导数:

dx <- function(x,y) {4*x**3-4*x*y-42*x+4*x*y-14}
dy <- function(x,y) {4*y**3+2*x**2-26*y+4*x*y-22}

运行梯度下降:

# gradient descent parameters
num_iter <- 100
learning_rate <- 0.001
x_val <- 6
y_val <- 6

updates_x <- vector("numeric", length = num_iter)
updates_y <- vector("numeric", length = num_iter)
updates_z <- vector("numeric", length = num_iter)

# parameter updates
for (i in 1:num_iter) {

  dx_val = dx(x_val,y_val)
  dy_val = dy(x_val,y_val)

  x_val <- x_val-learning_rate*dx_val
  y_val <- y_val-learning_rate*dx_val
  z_val <- f(x_val, y_val)

  updates_x[i] <- x_val
  updates_y[i] <- y_val
  updates_z[i] <- z_val
}

绘图:

x <- seq(-6, 6, length = 100)
y <- x
z <- outer(x, y, f)

plt <- persp(x, y, z,
               theta = -50-log(i), phi = 20+log(i),
               expand = 0.5,
               col = "lightblue", border = 'lightblue',
               axes = FALSE, box = FALSE,
               ltheta = 60, shade = 0.90
  )

points(trans3d(updates_x[1:i], updates_y[1:i], updates_z[1:i],pmat = plt),
       col = c(rep('white', num_iter-1), 'blue'),
       pch = 16,
       cex = c(rep(0.5, num_iter-1), 1))

在此处输入图片说明

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM