Learning from Imbalanced Data



By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH


1. Imbalanced Data

  • Consider binary classification

  • Often the classes are highly imbalanced



  • Should I feel happy if the classifier gets 99.997% classification accuracy on test data?
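Here is a minimal sketch of what such a number can hide. It assumes a hypothetical test set of 100,000 examples with only 3 positives; these counts are illustrative and not from the source. A "classifier" that ignores its input and always predicts the majority class scores exactly 99.997% accuracy but finds none of the positives.

```python
# Hypothetical test set: 3 positives (label 1) among 100,000 examples
n_total = 100_000
n_pos = 3
y_true = [1] * n_pos + [0] * (n_total - n_pos)

# Trivial "classifier": always predict the majority (negative) class
y_pred = [0] * n_total

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n_total
positives_found = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))

print(f"accuracy = {accuracy:.3%}")                       # accuracy = 99.997%
print(f"positives found = {positives_found} of {n_pos}")  # positives found = 0 of 3
```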

1.1. True Definition of Imbalanced Data?

  • Debatable ...

  • Scenario 1: 100,000 negative and 1,000 positive examples

  • Scenario 2: 10,000 negative and 10 positive examples

  • Scenario 3: 1,000 negative and 1 positive example

  • Usually, imbalance is characterized by absolute rather than relative rarity

    • Finding needles in a haystack...
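A little arithmetic makes the scenarios easier to compare. Scenarios 2 and 3 have the same relative imbalance (1,000 negatives per positive), yet Scenario 3 leaves only a single positive example to learn from. Scenario 1 is less skewed in ratio and still offers 1,000 positives. The dictionary layout below is just an illustrative choice.

```python
# (negatives, positives) for the three scenarios above
scenarios = {
    "Scenario 1": (100_000, 1_000),
    "Scenario 2": (10_000, 10),
    "Scenario 3": (1_000, 1),
}

# Relative rarity: negatives per positive
ratios = {name: n_neg / n_pos for name, (n_neg, n_pos) in scenarios.items()}

for name, (n_neg, n_pos) in scenarios.items():
    # Absolute rarity: how many positives there actually are
    print(f"{name}: {ratios[name]:.0f}:1 imbalance, {n_pos} positive example(s)")
```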

1.2. Minimizing Loss

  • Any model trained to minimize the total training loss, e.g.,


$$ \text{Classification}: \hat\omega = \arg \min_{\omega} \sum_{n=1}^{N} \ell \left(y_n,\omega^T x_n \right) $$


    ... will usually achieve high accuracy

  • However, it will be highly biased towards predicting the majority class
    • Thus accuracy alone cannot be trusted as the evaluation measure if we care more about correctly predicting the minority class (say, positive)
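The following minimal sketch shows this bias under two simplifying assumptions that are not in the source. First, it uses the logistic loss $\ell(y, s) = \log(1 + e^{-ys})$ with labels $y \in \{+1, -1\}$. Second, the model is intercept-only: every input gets the same score $b$ instead of $\omega^T x_n$. Gradient descent on the average loss, using Scenario 1's counts, drives $b$ to $\log(1000/100000) < 0$. Every example is therefore predicted negative, which gives about 99% accuracy and zero recall.

```python
import math

def mean_loss_grad(b, n_pos, n_neg):
    """d/db of the average logistic loss log(1 + exp(-y*b)) when every
    example receives the same score b (labels y in {+1, -1})."""
    n = n_pos + n_neg
    # Each positive contributes -1/(1+e^b); each negative contributes 1/(1+e^-b)
    return (-n_pos / (1 + math.exp(b)) + n_neg / (1 + math.exp(-b))) / n

def fit_intercept(n_pos, n_neg, lr=1.0, steps=5000):
    """Plain gradient descent on the single parameter b, starting from 0."""
    b = 0.0
    for _ in range(steps):
        b -= lr * mean_loss_grad(b, n_pos, n_neg)
    return b

n_pos, n_neg = 1_000, 100_000  # Scenario 1
b = fit_intercept(n_pos, n_neg)

# Predict positive when the score is > 0; with a constant score this is all-or-nothing
predicts_positive = b > 0
tp = n_pos if predicts_positive else 0
tn = 0 if predicts_positive else n_neg
accuracy = (tp + tn) / (n_pos + n_neg)
recall = tp / n_pos

print(f"b = {b:.4f}  (closed form log(n_pos/n_neg) = {math.log(n_pos / n_neg):.4f})")
print(f"accuracy = {accuracy:.4f}, recall = {recall}")
```

Setting the gradient to zero gives $\sigma(b) = n_{pos} / (n_{pos} + n_{neg})$, the positive base rate. Whenever positives are the minority this is below 0.5, so the loss-minimizing constant score is negative.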

1.3. Better Evaluation Measures

  • Precision: What fraction of positive predictions are truly positive


$$ P = \frac{\# \text{ examples correctly predicted as positive}}{\# \text{ examples predicted as positive}}$$


  • Recall: What fraction of all positive examples are predicted as positive


$$ R = \frac{\# \text{ examples correctly predicted as positive}}{\# \text{ total positive examples in the test set}}$$




  • There is often a trade-off between precision and recall. They can also be combined to yield other measures such as the F1 score, the AUC score, etc.
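The two definitions translate directly into code. This minimal sketch in plain Python takes F1 as the harmonic mean of precision and recall. Returning 0.0 when a denominator is zero is a convention chosen here, and libraries handle that case in different ways. The toy labels are hypothetical.

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    """Precision, recall and F1 for the class labelled `positive`."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    predicted_pos = sum(p == positive for p in y_pred)  # examples predicted as positive
    actual_pos = sum(t == positive for t in y_true)     # total positives in the test set
    precision = tp / predicted_pos if predicted_pos else 0.0
    recall = tp / actual_pos if actual_pos else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical test set: 4 positives among 10 examples
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]
p, r, f1 = precision_recall_f1(y_true, y_pred)
print(f"precision = {p:.3f}, recall = {r:.3f}, F1 = {f1:.3f}")
# precision = 0.667, recall = 0.500, F1 = 0.571

# The always-negative classifier scores zero on all three
print(precision_recall_f1(y_true, [0] * len(y_true)))  # (0.0, 0.0, 0.0)
```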

1.4. Dealing with Class Imbalance

  • Modifying the training data (the class distribution)
    • Undersampling the majority class
    • Oversampling the minority class
    • Reweighting the examples
  • Modifying the learning model
    • Use loss functions customized to handle class imbalance
  • Reweighting can also be seen as a way to modify the loss function

2. Modifying the Training Data

2.1. Undersampling

Create a new training data set by:

  • including all $k$ "positive" examples
  • randomly picking $k$ "negative" examples



This throws away a lot of data/information, but training is efficient because the new data set is small
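Here is a minimal sketch of the undersampling recipe in plain Python. It assumes the data is held as two lists of (label, id) records; the record format and the `undersample` helper are hypothetical. `random.sample` draws the $k$ negatives without replacement.

```python
import random

def undersample(positives, negatives, seed=0):
    """Keep all k positives plus k negatives drawn at random without replacement."""
    rng = random.Random(seed)
    k = len(positives)
    return positives + rng.sample(negatives, k)

# Hypothetical data: 10 positives, 10,000 negatives
positives = [("pos", i) for i in range(10)]
negatives = [("neg", i) for i in range(10_000)]

balanced = undersample(positives, negatives)
print(len(balanced))  # 20: the other 9,990 negatives are discarded
```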

2.2. Oversampling

Create a new training data set by:

  • including all $m$ "negative" examples
  • including $m$ "positive" examples:
    • repeat each example a fixed number of times, or
    • sample with replacement
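Both variants fit in a few lines. As before, this sketch assumes hypothetical (label, id) records, and the two helper functions are illustrative names. When $m$ is not a multiple of the number of positives, the fixed-repetition variant rounds down, so it yields slightly fewer than $m$ positives.

```python
import random

def oversample_repeat(positives, negatives):
    """Repeat every positive the same number of times (about m positives in total)."""
    reps = len(negatives) // len(positives)
    return negatives + positives * reps

def oversample_with_replacement(positives, negatives, seed=0):
    """Draw exactly m positives uniformly at random with replacement."""
    rng = random.Random(seed)
    m = len(negatives)
    return negatives + rng.choices(positives, k=m)

# Hypothetical data: 10 positives, 1,000 negatives
positives = [("pos", i) for i in range(10)]
negatives = [("neg", i) for i in range(1_000)]

repeated = oversample_repeat(positives, negatives)
resampled = oversample_with_replacement(positives, negatives)
print(len(repeated), len(resampled))  # 2000 2000
```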



  • From the loss function's perspective, the repeated examples simply contribute multiple times to the loss function

  • Oversampling usually tends to outperform undersampling because more data is used to train the model

  • Some oversampling methods, such as SMOTE, create synthetic examples of the minority class instead of repeating existing ones
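The core idea behind SMOTE can be sketched with NumPy. Each synthetic point is placed at a random position on the line segment between a minority example and one of its $k$ nearest minority-class neighbours. This is a simplified illustration of that description, not a reference implementation, and `smote_like` is a hypothetical helper. In practice, a maintained implementation such as the `SMOTE` class in the imbalanced-learn package would normally be used.

```python
import numpy as np

def smote_like(X_min, n_new, k=5, seed=0):
    """Simplified SMOTE-style oversampling of a minority-class array X_min (n, d).
    Assumes k <= n - 1."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority examples
    dist = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)                # exclude each point as its own neighbour
    neighbours = np.argsort(dist, axis=1)[:, :k]  # k nearest neighbours of each point
    base = rng.integers(0, len(X_min), size=n_new)              # random seed points
    partner = neighbours[base, rng.integers(0, k, size=n_new)]  # one random neighbour each
    gap = rng.random((n_new, 1))                  # interpolation factor in [0, 1)
    return X_min[base] + gap * (X_min[partner] - X_min[base])

# Hypothetical 2-D minority class with 6 examples
X_min = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],
                  [1.0, 1.0], [0.5, 0.5], [2.0, 2.0]])
X_syn = smote_like(X_min, n_new=20, k=3)
print(X_syn.shape)  # (20, 2)
```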

2.3. Reweighting Examples

Add costs/weights to the training set

  • "negative" examples get weight 1

  • "positive" examples get a much larger weight

Change the learning algorithm to optimize the weighted training error



  • Has a similar effect to oversampling but is more efficient (because examples are not duplicated)

  • Also requires a classifier that can learn with weighted examples
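The link between reweighting and oversampling can be checked numerically. This sketch assumes a logistic loss, labels in {+1, -1}, and a fixed set of hypothetical model scores. Giving each positive an integer weight $r$ then produces exactly the same total loss as repeating each positive $r$ times, while each example is visited only once. In scikit-learn, for instance, many estimators expose this through a `sample_weight` argument to `fit` or a `class_weight` parameter.

```python
import math

def logistic_loss(y, s):
    """Logistic loss for label y in {+1, -1} and model score s."""
    return math.log1p(math.exp(-y * s))

def weighted_loss(data, w_pos, w_neg=1.0):
    """Sum of per-example losses, each multiplied by the weight of its class."""
    return sum((w_pos if y == 1 else w_neg) * logistic_loss(y, s) for y, s in data)

# Hypothetical (label, score) pairs produced by some fixed model
data = [(1, 0.3), (1, -1.2), (-1, -0.5), (-1, 0.8), (-1, -2.0), (-1, 0.1)]
r = 3

loss_weighted = weighted_loss(data, w_pos=r)

# Same data with every positive repeated r times, unweighted
oversampled = [(y, s) for y, s in data for _ in range(r if y == 1 else 1)]
loss_oversampled = sum(logistic_loss(y, s) for y, s in oversampled)

print(loss_weighted, loss_oversampled)  # equal up to floating-point rounding
print(len(data), len(oversampled))      # 6 10
```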

3. Modifying the Loss Function

3.1. Loss Functions Customized for Imbalanced Data

  • Traditional loss functions have the form $\sum_{n=1}^{N} \ell \left( y_n, f(x_n)\right)$

  • Such loss functions look at positive and negative examples individually, so the majority class tends to overwhelm the minority class

  • Reweighting the loss function differently for different classes can be one way to handle class imbalance, e.g., $\sum_{n=1}^{N} C_{y_n} \ell \left( y_n, f(x_n)\right)$

  • Alternatively, we can use loss functions that look at pairs of examples (a positive example $x_n^+$ and a negative example $x_m^-$). For example:


$$ \ell \left( f(x_n^+), f(x_m^-)\right) = \begin{cases} 0, & \text{if }\; f(x_n^+) > f(x_m^{-})\\ 1, & \text{otherwise} \end{cases} $$


  • These are called "pairwise" loss functions

  • Why is it a good loss function for imbalanced data?
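Here is a minimal sketch of the pairwise loss above, averaged over all (positive, negative) pairs and using hypothetical scores. Every pair contains exactly one positive, so the positives cannot be outnumbered in the loss. A constant scorer that calls everything negative would still be 80% accurate on this 2-positive, 8-negative set, yet it incurs the maximum pairwise loss of 1. One minus this average is the fraction of correctly ordered pairs, which equals the empirical AUC when there are no tied scores.

```python
def pairwise_01_loss(scores_pos, scores_neg):
    """Average over all (positive, negative) pairs of the pairwise loss:
    0 if the positive is scored strictly higher than the negative, 1 otherwise."""
    losses = [0 if sp > sn else 1 for sp in scores_pos for sn in scores_neg]
    return sum(losses) / len(losses)

# Hypothetical scores f(x): 2 positives, 8 negatives
scores_neg = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Constant scorer: every pair is a tie, so every pair is penalized
print(pairwise_01_loss([0.0, 0.0], [0.0] * 8))   # 1.0

# Scorer that ranks the positives above most negatives: 1 of 16 pairs misordered
print(pairwise_01_loss([0.75, 0.9], scores_neg))  # 0.0625
```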

3.2. Pairwise Loss Functions
