Sunday, July 15, 2012

Plot classification regions in an SVM

I've been really busy. I can't really claim to be an SVM expert, since my postgraduate work so far has not dealt much with machine learning, but lately I've been working with SVMs and kernels.

In fact, I have a rather advanced technique that substantially boosts SVM performance, but that will have to wait until the paper is published.

One of the tests I ran dealt with a modified XOR problem. The data consist of four groups drawn from bivariate normal distributions. The groups are assigned to two classes such that groups of the same class are always separated from each other by a group of the other class.
library(MASS)  # provides mvrnorm()
# Groups 1 and 2 form class 1; groups 3 and 4 form class -1
y=mvrnorm(50,c(3,5),Sigma=diag(c(0.5,1.5)))
y=rbind(y,mvrnorm(50,c(15,13),Sigma=diag(c(1.5,0.5))))
y=rbind(y,mvrnorm(50,c(7,5),Sigma=diag(c(0.5,1.5))))
y=rbind(y,mvrnorm(50,c(15,17),Sigma=diag(c(1.5,0.5))))
labels=c(rep(1,100),rep(-1,100))
I put the figure here so that the problem is clear; the explanation of how to produce it follows.
library(e1071)  # provides svm()
# Phi holds the kernel features of the training points y (see the previous post)
K.svm=svm(Phi, labels, type="C", kernel="linear", probability=T)
xs=seq(0, 20, length.out=100)
X=as.matrix(expand.grid(list(x = xs, y = xs)))
# compute the kernel features of the grid points X here, in the same way as Phi, giving PhiX
im=predict(K.svm, PhiX)
# expand.grid varies x fastest, so filling by column puts x along the rows, as image() expects
im=matrix(as.numeric(im), nrow=100, byrow=F)
image(xs, xs, im, xlab="", ylab="", col=c("#FFFCCCFF","#FFF000FF"))  # or col=heat.colors(2)
points(y)
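The code above depends on Phi and PhiX from the previous post, so it does not run on its own. As a self-contained sketch, the version below assumes one possible kernel feature map: the empirical Gaussian kernel map, where each point's features are its Gaussian kernel values against the 200 training points, fed to a linear SVM. The helper gauss_kernel, the bandwidth sigma = 2 and the seed are illustrative assumptions, not necessarily what the previous post used.

```r
library(MASS)   # mvrnorm()
library(e1071)  # svm(), predict()

set.seed(42)  # assumption: fixed seed only so the sketch is reproducible
y <- rbind(mvrnorm(50, c(3, 5),   Sigma = diag(c(0.5, 1.5))),
           mvrnorm(50, c(15, 13), Sigma = diag(c(1.5, 0.5))),
           mvrnorm(50, c(7, 5),   Sigma = diag(c(0.5, 1.5))),
           mvrnorm(50, c(15, 17), Sigma = diag(c(1.5, 0.5))))
labels <- c(rep(1, 100), rep(-1, 100))

# Hypothetical helper: Gaussian kernel values between every row of A and every row of B
gauss_kernel <- function(A, B, sigma = 2) {
  d2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)  # squared distances
  exp(-d2 / (2 * sigma^2))
}

# Training features: row i holds the kernel values of point i against all training points
Phi <- gauss_kernel(y, y)
K.svm <- svm(Phi, labels, type = "C", kernel = "linear")

# Grid over the region to paint, mapped into the same feature space
xs <- seq(0, 20, length.out = 100)
X <- as.matrix(expand.grid(x = xs, y = xs))
PhiX <- gauss_kernel(X, y)

# Classify every grid point and reshape the predictions into a 100 x 100 matrix
im <- predict(K.svm, PhiX)
im <- matrix(as.numeric(im), nrow = 100)

image(xs, xs, im, xlab = "", ylab = "", col = c("#FFFCCCFF", "#FFF000FF"))
points(y)
```

Swapping gauss_kernel for another kernel only changes how Phi and PhiX are computed; the grid, prediction and plotting steps stay the same.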
First, we train the SVM on the kernel features, as explained in the previous post.
Then we create a grid spanning the region we want to paint and evaluate the trained SVM at every grid point. Finally, we reshape the vector of predicted classes back into a 2D grid and plot it together with the original points.
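The reshaping step is easy to get transposed. A minimal sketch on a tiny grid, with a made-up function z standing in for the SVM predictions, shows why matrix(..., nrow = length(xs)) lines up with image(): expand.grid varies its first variable fastest, and matrix() fills column by column, so row i corresponds to xs[i] and column j to ys[j]. The two axes get different lengths here only so that a transposition would be caught.

```r
# Tiny grid with different axis lengths so the orientation is easy to check
xs <- seq(0, 20, length.out = 5)   # 0, 5, 10, 15, 20
ys <- seq(0, 10, length.out = 3)   # 0, 5, 10
G <- expand.grid(x = xs, y = ys)   # x varies fastest: (x1,y1), (x2,y1), ..., (x1,y2), ...
z <- G$x + 100 * G$y               # stand-in for the prediction at each grid point
Z <- matrix(z, nrow = length(xs))  # column-major fill: row i <-> xs[i], column j <-> ys[j]
Z[2, 3]                            # 1005, i.e. xs[2] + 100 * ys[3]
image(xs, ys, Z)                   # image() expects nrow(Z) == length(xs)
```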

1 comment:

  1. Hi, good tutorial ;) I have a problem with the plot function of e1071:
    #this is my code
    samples<-sample(nrow(mydata),nrow(mydata)*0.6)
    dtrain<-mydata[samples,]
    dtest<-mydata[-samples,]
    sv<-svm(Crise~.,dtrain,type = "C-classification", kernel = "radial",cost = 10, gamma = 0.1)
    summary(sv)
    plot(sv,dtrain)
    Thank you.
