
R: How to plot the hyperplane and margins of an svm in ggplot2?

I'm following along with Tibshirani's ISL text (An Introduction to Statistical Learning) and trying to plot the results of an SVM in ggplot2. I can get the points and the support vectors, but I can't figure out how to draw the hyperplane and margins for the 2D case. I've searched Google and checked the e1071 readme. A general, dynamic solution (applicable to a variety of SVM kernels, costs, etc.) would be great. Here is my MWE:

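library(e1071)
library(ggplot2)

N = 20  # assumption: N is not defined in the original MWE; any even number works
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1  # shift class 1 away from class -1
dat = data.frame(x=x, y=as.factor(y))
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)

df = dat
df = cbind(df, sv=rep(0,nrow(df)))  # sv = 1 marks a support vector
df[svmfit$index,]$sv = 1

# the original call was cut off after "+"; plotting the points with the
# support vectors as a separate shape is an assumed completion
ggplot(data=df,aes(x=x.1,y=x.2,group=y,color=y)) +
  geom_point(aes(shape=factor(sv)))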

Something like this (from Python's scikit-learn):
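For the linear kernel in this MWE, the hyperplane and margins have a closed form. The sketch below assumes e1071's fitted fields coefs (coefficients times training labels), SV (the support vectors) and rho (the negative intercept), and scale = FALSE, so that SV is on the original axes:

```latex
f(\mathbf{x}) = \sum_{i \in \mathrm{SV}} c_i \langle \mathbf{x}_i, \mathbf{x} \rangle - \rho = \mathbf{w}^\top \mathbf{x} - \rho,
\qquad \mathbf{w} = \sum_{i \in \mathrm{SV}} c_i \mathbf{x}_i

\text{hyperplane: } f(\mathbf{x}) = 0, \qquad \text{margins: } f(\mathbf{x}) = \pm 1

w_1 x_1 + w_2 x_2 - \rho = k \;\Longrightarrow\; x_2 = \frac{\rho + k - w_1 x_1}{w_2}, \qquad k \in \{-1, 0, 1\}
```

Here c_i is svmfit$coefs, x_i are the rows of svmfit$SV and rho is svmfit$rho. For non-linear kernels there is no w in the input plane, so the boundary and margins have to be traced as the level sets f = 0 and f = ±1 over a grid instead.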

So you don't want to plot the support vectors, right? Here's something very basic that works for your example, based on the plot.svm source code. You can construct something much richer by taking a look at that source code.

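library(e1071)
library(ggplot2)

N = 20  # assumption: N is not defined in the original code; any even number works
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1
dat = data.frame(x=x, y=as.factor(y))
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)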

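# evaluate the classifier on a 100 x 100 grid spanning both features
grid <- expand.grid(seq(min(dat[, 1]), max(dat[, 1]), length.out=100),
                    seq(min(dat[, 2]), max(dat[, 2]), length.out=100))
names(grid) <- names(dat)[1:2]  # x.1, x.2, so predict() finds the model's variables
preds <- predict(svmfit, grid)
df <- data.frame(grid, preds)
# x.2 on the horizontal axis and x.1 on the vertical axis, as plot.svm does
ggplot(df, aes(x = x.2, y = x.1, fill = preds)) + geom_tile()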
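The tiles show the predicted class but not the hyperplane or margins. A more general option, closer to the kernel-agnostic solution the question asks for, is to keep the decision values that predict() returns with decision.values = TRUE and draw their contours at 0 (the decision boundary) and at -1 and +1 (the margins). This is a sketch, not what plot.svm does internally; set.seed(1) and N = 20 are illustrative assumptions, since the original code never defines N.

```r
# Sketch: decision boundary (f = 0) and margins (f = -1, +1) for any kernel,
# traced as contours of e1071's decision values over a regular grid.
library(e1071)
library(ggplot2)

set.seed(1)   # assumption: seed only for reproducibility
N <- 20       # assumption: N is not defined in the original code
x <- matrix(rnorm(n = N * 2), ncol = 2)
y <- c(rep(-1, N / 2), rep(1, N / 2))
x[y == 1, ] <- x[y == 1, ] + 1
dat <- data.frame(x = x, y = as.factor(y))
svmfit <- svm(y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)

# 100 x 100 grid over the range of both features
grid <- expand.grid(x.1 = seq(min(dat$x.1), max(dat$x.1), length.out = 100),
                    x.2 = seq(min(dat$x.2), max(dat$x.2), length.out = 100))

# decision.values = TRUE attaches f(x) as the "decision.values" attribute
pred <- predict(svmfit, grid, decision.values = TRUE)
grid$dv <- as.vector(attr(pred, "decision.values"))
# rebuild the class factor without the extra attributes predict() attached
grid$preds <- factor(as.character(pred), levels = levels(dat$y))

p <- ggplot(grid, aes(x = x.2, y = x.1)) +
  geom_tile(aes(fill = preds), alpha = 0.3) +
  geom_contour(aes(z = dv), breaks = 0, colour = "black") +
  geom_contour(aes(z = dv), breaks = c(-1, 1), colour = "black", linetype = "dashed") +
  geom_point(data = dat, aes(colour = y)) +
  ggtitle("SVM decision boundary (solid) and margins (dashed)")
print(p)
```

Because only the decision values are used, changing kernel = "linear" to, for example, kernel = "radial" should give a curved boundary and margins with no other changes.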

This should output a tile plot of the predicted class at every grid point.


Compare this to the plot.svm output:

plot(svmfit, dat)



If you want to reproduce the points as well, I've altered the above code slightly:

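# point colours by class, tile fills by predicted class, shapes by support-vector status
cols <- c('1' = 'red', '-1' = 'black')
tiles <- c('1' = 'magenta', '-1' = 'cyan')
shapes <- c('support' = 4, 'notsupport' = 1)
# mark the support vectors (svmfit$index holds their row numbers in dat)
dat$support <- 'notsupport'
dat[svmfit$index, 'support'] <- 'support'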

ggplot(df, aes(x = x.2, y = x.1)) + geom_tile(aes(fill = preds)) + 
  scale_fill_manual(values = tiles) +
  geom_point(data = dat, aes(color = y, shape = support), size = 2) +
  scale_color_manual(values = cols) +
  scale_shape_manual(values = shapes) +
  ggtitle('SVM classification plot')
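For the linear kernel specifically, the hyperplane and margins are straight lines, so they can be overlaid with geom_abline(). A minimal sketch, assuming scale = FALSE as above (with scaling, svmfit$SV is stored on the scaled axes and the lines would need transforming back). It recovers w from svmfit$coefs and svmfit$SV and uses f(x) = w·x - rho; the seed and N = 20 are again illustrative assumptions.

```r
# Sketch (linear kernel only): hyperplane and margins as straight lines
library(e1071)
library(ggplot2)

set.seed(1)   # assumption: seed only for reproducibility
N <- 20       # assumption: N is not defined in the original code
x <- matrix(rnorm(n = N * 2), ncol = 2)
y <- c(rep(-1, N / 2), rep(1, N / 2))
x[y == 1, ] <- x[y == 1, ] + 1
dat <- data.frame(x = x, y = as.factor(y))
svmfit <- svm(y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)

# f(x) = w . x - rho, with w = sum_i coefs_i * SV_i (valid because scale = FALSE)
w <- drop(t(svmfit$coefs) %*% svmfit$SV)   # named vector: x.1, x.2
rho <- svmfit$rho

# x.2 is on the horizontal axis and x.1 on the vertical axis (as in plot.svm),
# so solve w["x.1"] * x.1 + w["x.2"] * x.2 - rho = k for x.1, with k in {-1, 0, 1}
slope <- unname(-w["x.2"] / w["x.1"])
intercepts <- unname((rho + c(-1, 0, 1)) / w["x.1"])

dat$support <- "notsupport"
dat[svmfit$index, "support"] <- "support"

p <- ggplot(dat, aes(x = x.2, y = x.1)) +
  geom_point(aes(colour = y, shape = support), size = 2) +
  scale_shape_manual(values = c(support = 4, notsupport = 1)) +
  geom_abline(slope = slope, intercept = intercepts[2]) +
  geom_abline(slope = rep(slope, 2), intercept = intercepts[c(1, 3)],
              linetype = "dashed") +
  ggtitle("Linear SVM: hyperplane (solid) and margins (dashed)")
print(p)
```

The dashed lines pass through the support vectors whose coefficients are below the cost bound; support vectors at the bound can sit inside the margin or on the wrong side of the hyperplane.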


