Machine Learning 201: Does Balancing Classes Improve Classifier Performance?

The author investigates whether balancing classes improves performance for logistic regression, SVM, and random forests, and finds where it helps and where it does not.

By Nina Zumel (Win-Vector).

It’s a folk theorem I sometimes hear from colleagues and clients: that you must balance the class prevalence before training a classifier. Certainly, I believe that classification tends to be easier when the classes are nearly balanced, especially when the class you are actually interested in is the rarer one. But I have always been skeptical of the claim that artificially balancing the classes (through resampling, for instance) always helps, when the model is to be run on a population with the native class prevalences.

On the other hand, there are situations where balancing the classes, or at least enriching the prevalence of the rarer class, might be necessary, if not desirable. Fraud detection, anomaly detection, or other situations where positive examples are hard to get, can fall into this case. In this situation, I’ve suspected (without proof) that SVM would perform well, since the formulation of hard-margin SVM is pretty much distribution-free. Intuitively speaking, if both classes are far away from the margin, then it shouldn’t matter whether the rare class is 10% or 49% of the population. In the soft-margin case, of course, distribution starts to matter again, but perhaps not as strongly as with other classifiers like logistic regression, which explicitly encodes the distribution of the training data.
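For reference, here are the standard textbook formulations behind this intuition. They are our addition, not from the original post, and use labels y_i ∈ {−1, +1} and a generic feature map φ. In the hard-margin problem, which assumes separable data, duplicating a training point only repeats an existing constraint, so reweighting by duplication leaves the solution unchanged. In the soft-margin problem, each copy of a margin-violating point adds another slack term ξ_i to the objective. In ridge-penalized logistic regression, every training point adds a loss term. In both cases resampling changes how much each class weighs in the fit. The penalty scaling on the logistic line is schematic; glmnet normalizes its loss and penalty somewhat differently.

```latex
\begin{aligned}
&\text{Hard margin:} && \min_{w,b}\ \tfrac{1}{2}\lVert w\rVert^2
  \ \text{ s.t. } y_i\,(w^\top\phi(x_i)+b) \ge 1 \ \ \forall i \\
&\text{Soft margin:} && \min_{w,b,\xi}\ \tfrac{1}{2}\lVert w\rVert^2 + C\sum_i \xi_i
  \ \text{ s.t. } y_i\,(w^\top\phi(x_i)+b) \ge 1-\xi_i,\ \ \xi_i \ge 0 \\
&\text{Ridge logistic:} && \min_{w,b}\ \sum_i \log\bigl(1+e^{-y_i(w^\top x_i+b)}\bigr)
  + \tfrac{\lambda}{2}\lVert w\rVert^2
\end{aligned}
```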

So let’s run a small experiment to investigate this question.

Experimental Setup

We used the ISOLET dataset, available at the UCI Machine Learning repository. The task is to recognize spoken letters. The training set consists of 120 speakers, each of whom uttered the letters A-Z twice; 617 features were extracted from the utterances. The test set is another 30 speakers, each of whom also uttered A-Z twice.

Our chosen task was to identify the letter “n”. This target class has a native prevalence of about 3.8% in both test and training, and must be identified from among several other distinct co-existing populations. This is similar to a fraud detection situation, where a specific rare event has to be identified from a population of disparate “innocent” events.

We trained our models against a training set where the target was present at its native prevalence; against training sets where the target prevalence was enriched by resampling to twice, five times, and ten times its native prevalence; and against a training set where the target prevalence was enriched to 50%. This replicates some plausible enrichment scenarios: enriching the rare class by a large multiplier, or simply balancing the classes. All training sets were the same size (N=2000). We then ran each model against the same test set (with the target variable at its native prevalence) to evaluate model performance. We used a threshold of 50% to assign class labels (that is, we labeled the data by the most probable label). To get a more stable estimate of how enrichment affected performance, we ran this loop ten times and averaged the results for each model type.
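Here is a minimal R sketch of this protocol, under stated assumptions. The value 0.038 stands in for the “about 3.8%” native prevalence. evalOnce is a hypothetical function you would supply: it builds one enriched training set of size N at the given prevalence, fits one model type, labels the native-prevalence test set, and returns a named vector of metrics. The helper names assignLabels and runEnrichmentExperiment are likewise illustrative, not the author's code.

```r
# ASSUMPTION: 0.038 stands in for the "about 3.8%" native prevalence.
nativePrevalence = 0.038
prevalences = c(native   = nativePrevalence,
                x2       = 2 * nativePrevalence,
                x5       = 5 * nativePrevalence,
                x10      = 10 * nativePrevalence,
                balanced = 0.5)

# Label a datum TRUE when its estimated P(target) exceeds the threshold;
# with threshold = 0.5 this picks the more probable label.
assignLabels = function(prob, threshold = 0.5) prob > threshold

# Call the hypothetical evalOnce(prevalence) nreps times per prevalence
# and average the returned metric vectors. Result: one row per
# prevalence, one column per metric.
runEnrichmentExperiment = function(evalOnce, prevalences, nreps = 10) {
  rows = lapply(prevalences, function(p) {
    reps = do.call(rbind, lapply(seq_len(nreps), function(i) evalOnce(p)))
    colMeans(reps)
  })
  do.call(rbind, rows)
}
```

Running this once per model type (logistic, random forest, SVM) gives the averaged comparison described above.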

We tried three model types:

  • cv.glmnet from R package glmnet: Regularized logistic regression, with alpha=0 (L2 regularization, or ridge). cv.glmnet chooses the regularization penalty by cross-validation.
  • randomForest from R package randomForest: Random forest with the default settings (500 trees, nvar/3, or about 205 variables drawn at each node).
  • ksvm from R package kernlab: Soft-margin SVM with the radial basis kernel and C=1.
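A minimal sketch of how the three fits might look in R, assuming x is a numeric feature matrix and y a logical target. The wrapper names fitLogistic, fitForest, and fitSVM are hypothetical, and the settings beyond those listed above (binomial family, prob.model = TRUE, an explicit mtry) are our assumptions rather than the author's code. Each wrapper returns a function that maps new rows to an estimated P(y = TRUE), ready for a 50% threshold.

```r
# Packages are attached inside each fitter so that this block can be
# defined without them; fitting requires glmnet, randomForest, kernlab.
fitLogistic = function(x, y) {
  library(glmnet)
  # ridge penalty (alpha = 0); cv.glmnet picks lambda by cross-validation
  m = cv.glmnet(x, factor(y), family = "binomial", alpha = 0)
  # P(TRUE) at cv.glmnet's default choice of lambda ("lambda.1se")
  function(newx) as.numeric(predict(m, newx = newx, type = "response"))
}

fitForest = function(x, y) {
  library(randomForest)
  # randomForest's classification default is mtry = floor(sqrt(p)), so
  # p/3 (about 205 for 617 features) is passed explicitly to match the text
  m = randomForest(x = x, y = factor(y), ntree = 500,
                   mtry = floor(ncol(x) / 3))
  function(newx) predict(m, newx, type = "prob")[, "TRUE"]
}

fitSVM = function(x, y) {
  # attaching kernlab lets predict() dispatch to its method for ksvm
  library(kernlab)
  # soft-margin C-classification, Gaussian RBF kernel, C = 1;
  # prob.model = TRUE so that class probabilities can be thresholded
  m = ksvm(x, factor(y), type = "C-svc", kernel = "rbfdot",
           C = 1, prob.model = TRUE)
  function(newx) predict(m, newx, type = "probabilities")[, "TRUE"]
}
```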

Since there are many ways to resample the data for enrichment, here’s how I did it. The target variable is assumed to be TRUE/FALSE, with TRUE as the class of interest (the rare one). dataf is the data frame of training data, target is the name of the target column, N is the desired size of the enriched training set, and prevalence is the desired target prevalence.

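makePrevalence = function(dataf, target,
                          prevalence, N) {
  # indices of T/F
  tset_ix = which(dataf[[target]])
  others_ix = which(!dataf[[target]])
  # number of target rows needed for the desired prevalence
  ntarget = round(N*prevalence)
  # sample with replacement: the rare class may have fewer rows than ntarget
  heads = sample(tset_ix, size=ntarget,
                 replace=TRUE)
  tails = sample(others_ix, size=(N-ntarget),
                 replace=TRUE)
  dataf[c(heads, tails),]
}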
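Simple counting shows why the rare class has to be sampled with replacement at the higher enrichment levels. The sketch below assumes the training set contains 120 speakers × 2 utterances = 240 distinct examples of “n” (no missing recordings) and again uses 0.038 for the native prevalence. It compares how many target rows each N=2000 training set requests with how many exist.

```r
# ASSUMPTIONS: 240 distinct "n" rows in training; 0.038 native prevalence.
nTargetAvailable = 120 * 2
N = 2000
nativePrevalence = 0.038
prevalences = c(native = 1, x2 = 2, x5 = 5, x10 = 10) * nativePrevalence
prevalences = c(prevalences, balanced = 0.5)

# target rows each enriched training set of size N must contain
ntargetNeeded = round(N * prevalences)
# TRUE where the request exceeds the supply, so rare rows get duplicated
needsReplacement = ntargetNeeded > nTargetAvailable
print(data.frame(ntargetNeeded, needsReplacement))
```

Under these assumptions the 5×, 10×, and balanced sets request 380, 760, and 1000 target rows. In the balanced set, each distinct “n” utterance therefore appears a little over four times on average (1000/240 ≈ 4.2).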

Training at the Native Target Prevalence

Before we run the full experiment, let’s look at how each of these three modeling approaches does when we fit models the obvious way — where the training and test sets have the same distribution:

## [1] "Metrics on training data"
## accuracy precision   recall specificity         label
##   0.9985 1.0000000 0.961039     1.00000      logistic
##   1.0000 1.0000000 1.000000     1.00000 random forest
##   0.9975 0.9736842 0.961039     0.99896           svm
## [1] "Metrics on test data"
##  accuracy precision    recall specificity         label
## 0.9807569 0.7777778 0.7000000   0.9919947      logistic
## 0.9717768 1.0000000 0.2666667   1.0000000 random forest
## 0.9846055 0.7903226 0.8166667   0.9913276           svm

We looked at four metrics. Accuracy is simply the fraction of datums classified correctly. Precision is the fraction of datums classified as positive that really were; equivalently, it’s an estimate of the conditional probability of a datum being in the positive class, given that it was classified as positive. Recall (also called sensitivity or the true positive rate) is the fraction of positive datums in the population that were correctly identified. Specificity is the true negative rate, or one minus the false positive rate: the fraction of negative datums correctly identified as such.
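A minimal sketch of the four metrics, computed from logical vectors of predicted and actual labels. classMetrics is a hypothetical helper name, not a function from any package. The usage lines rely on counts we back-calculated from the logistic-regression row of the test-data table: 60 positives among 1,559 test rows, with 42 true positives and 12 false positives. These counts are our reconstruction, not figures reported in the original.

```r
classMetrics = function(pred, truth) {
  tp = sum(pred & truth)     # positives correctly flagged
  fp = sum(pred & !truth)    # negatives wrongly flagged
  tn = sum(!pred & !truth)   # negatives correctly passed over
  fn = sum(!pred & truth)    # positives missed
  c(accuracy    = (tp + tn) / length(truth),
    precision   = tp / (tp + fp),  # NaN if nothing is classified positive
    recall      = tp / (tp + fn),  # true positive rate
    specificity = tn / (tn + fp))  # true negative rate = 1 - FPR
}

# Usage with the back-calculated counts described above
truth = c(rep(TRUE, 60), rep(FALSE, 1499))
pred  = c(rep(TRUE, 42), rep(FALSE, 18), rep(TRUE, 12), rep(FALSE, 1487))
round(classMetrics(pred, truth), 7)
```

Rounded to seven digits, this reproduces the logistic-regression line of the test-data metrics.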

As the table above shows, random forest did perfectly on the training data, and the other two did quite well, too, with nearly perfect precision/specificity and high recall. However, random forest’s recall plummeted on the hold-out set, to 27%. The other two models degraded as well (logistic regression more than SVM), but still managed to retain decent recall, along with good precision and specificity. Random forest also has the lowest accuracy on the test set, although 97% still looks pretty good. This is another reason why accuracy is not always a good metric for evaluating classifiers: since the target prevalence in the data set is only about 3.8%, a model that always returned FALSE would have an accuracy of 96.2%!
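The arithmetic behind that last figure is easy to check. This sketch uses a hypothetical 1,000-row test set at exactly 3.8% prevalence, not the actual ISOLET test set. A classifier that always answers FALSE gets every negative right and every positive wrong, so its accuracy is one minus the prevalence while its recall is zero.

```r
# ASSUMPTION: hypothetical test labels at exactly 3.8% prevalence
truth = c(rep(TRUE, 38), rep(FALSE, 962))
alwaysFalse = rep(FALSE, length(truth))

accuracy = mean(alwaysFalse == truth)           # all negatives correct
recall = sum(alwaysFalse & truth) / sum(truth)  # all positives missed
c(accuracy = accuracy, recall = recall)
```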

One could argue that if precision is the goal, then random forest is still in the running. However, remember that the goal here is to identify a rare event. In many such situations (like fraud detection) one would expect that high recall is the most important goal, as long as precision/specificity are still reasonable.

Let’s see if enriching the target class prevalence during training improves things.