Tuesday, March 20, 2012

Multiclass SVM with e1071

When dealing with multi-class classification using the R package e1071, which wraps LibSVM, one may run into trouble getting correct predictions, since the predict function does not always seem to handle this case properly. In fact, running the very example that comes with the svm help page (?svm at the R prompt), one can see the function fail even though the model itself is correctly fitted.

library(e1071)

data(iris)
attach(iris)

model <- svm(Species ~ ., data = iris)

x <- subset(iris, select = -Species)
y <- Species
model <- svm(x, y, probability = TRUE)

This model is correctly fitted. However,

pred <- predict(model, x)
does not return the correct values:

> pred
     1      2      3      4      5      6      7      8      9     10     11     12     13     14     15 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    16     17     18     19     20     21     22     23     24     25     26     27     28     29     30 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    31     32     33     34     35     36     37     38     39     40     41     42     43     44     45 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    46     47     48     49     50     51     52     53     54     55     56     57     58     59     60 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    61     62     63     64     65     66     67     68     69     70     71     72     73     74     75 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    76     77     78     79     80     81     82     83     84     85     86     87     88     89     90 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
    91     92     93     94     95     96     97     98     99    100    101    102    103    104    105 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
   106    107    108    109    110    111    112    113    114    115    116    117    118    119    120 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
   121    122    123    124    125    126    127    128    129    130    131    132    133    134    135 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
   136    137    138    139    140    141    142    143    144    145    146    147    148    149    150 
setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa 
Levels: setosa versicolor virginica
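
Counting the predictions per class makes the problem obvious; since pred is a factor, a simple summary gives the counts per level:

summary(pred)   # all 150 predictions fall under setosa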
We see that every prediction is setosa, even though the dataset is evenly divided among the three classes (setosa, versicolor and virginica). Something clearly went wrong with the direct prediction, but one way to overcome the problem is to use the predicted class probabilities, which do seem to be computed correctly:

pred <- predict(model, x, decision.values = TRUE, probability = TRUE)
Observe how they look:

> attr(pred, "probabilities")
         setosa  versicolor   virginica
1   0.980325653 0.011291686 0.008382661
2   0.972927977 0.018061300 0.009010723
3   0.979044327 0.011921880 0.009033793
...
48  0.977140826 0.013050710 0.009808464
49  0.977831001 0.013359834 0.008809164
50  0.980099521 0.011501036 0.008399444
51  0.017740468 0.954734399 0.027525133
52  0.010394167 0.973918376 0.015687457
...
97  0.009263806 0.986123276 0.004612918
98  0.008503724 0.988168405 0.003327871
99  0.025068812 0.965415124 0.009516064
100 0.007514580 0.987584706 0.004900714
101 0.012482541 0.002502134 0.985015325
...
149 0.013669944 0.017618659 0.968711397
150 0.010205071 0.140882630 0.848912299
Now we see that the most probable class is indeed the ground truth, so we can classify correctly with the following function:
predsvm <- function(model, newdata)
{
  ## class probabilities: one row per observation, one column per class
  prob <- attr(predict(model, newdata, probability = TRUE), "probabilities")

  ## for each observation, pick the class with the largest probability
  best <- apply(prob, 1, which.max)
  factor(colnames(prob)[best], levels = model$levels)
}
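
To check the wrapper, one can reuse the model, x and y defined above and compare its output against the true labels (pred2 is just an illustrative name):

pred2 <- predsvm(model, x)   # predictions via the probability matrix
table(pred2, y)              # confusion matrix against the true species
mean(pred2 == y)             # overall training accuracy
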
One might also write the following function, which computes the prediction directly from the way the support vectors and their coefficients are stored in the model object, in model$SV, model$coefs and model$rho:
## Linear kernel function
K <- function(i, j) crossprod(i, j)

## Predict the class of a single observation (a plain numeric vector)
## from the support vectors stored in the fitted model. Note that
## object$SV lives on the (possibly scaled) training scale, so this
## reproduces predict() for a model fitted with kernel = "linear" and
## scale = FALSE, or when newdata is transformed the same way.
predsvm <- function(object, newdata) {
  ## start index of each class' support vectors inside object$SV
  start <- c(1, cumsum(object$nSV) + 1)
  start <- start[-length(start)]

  ## kernel values between every support vector and the new observation
  kernel <- sapply(1:object$tot.nSV,
                   function(x) K(object$SV[x, ], newdata))

  ## raw decision value of the binary classifier for classes (i, j)
  predone <- function(i, j) {
    ## ranges of the support vectors belonging to classes i and j:
    ri <- start[i] : (start[i] + object$nSV[i] - 1)
    rj <- start[j] : (start[j] + object$nSV[j] - 1)

    ## dual coefficients of those support vectors for classifier (i, j):
    coef1 <- object$coefs[ri, j - 1]
    coef2 <- object$coefs[rj, i]

    ## return the raw value:
    crossprod(coef1, kernel[ri]) + crossprod(coef2, kernel[rj])
  }

  ## one-against-one voting over all pairs of classes
  votes <- rep(0, object$nclasses)
  c <- 0 # index into object$rho, one entry per binary classifier
  for (i in 1 : (object$nclasses - 1))
    for (j in (i + 1) : object$nclasses)
      if (predone(i, j) > object$rho[c <- c + 1])
        votes[i] <- votes[i] + 1
      else
        votes[j] <- votes[j] + 1

  ## return the winner (class with the most votes)
  object$levels[which(votes %in% max(votes))[1]]
}
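
As a quick sanity check, one might fit a linear, unscaled SVM (so that the assumptions of the hand-rolled kernel hold) and compare the manual predictions with those of predict; model.lin, xm and pred.manual below are purely illustrative names, while x and y are the iris objects defined earlier:

model.lin <- svm(x, y, kernel = "linear", scale = FALSE)

xm <- as.matrix(x)  # predsvm expects a plain numeric vector per observation
pred.manual <- sapply(1:nrow(xm), function(i) predsvm(model.lin, xm[i, ]))

## the manual votes should agree with the package's own predictions
table(pred.manual, predict(model.lin, x))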

3 comments:

  1. Thanks for the blog. It really helped.

  2. You might want to emphasize the fact that we need to set 'probability = TRUE' both when training the model and in the 'predict' call. The first time I read your blog I missed the point that we should set it while training the model too. It took me a bit more focus to get it right :)

  3. Nice tutorial. I have 40,000 training and 40,000 testing examples with 1024 attributes and 397 classes, and I want to use this approach. Does it work?
