Day61 ML Review - Class Imbalances
Use Other Metrics, Assign Different Class Weights, or Upsample the Minority Class
(explanation from S. Raschka and V. Mirjalili, Python Machine Learning: Machine Learning and Deep Learning with Python, scikit-learn, and TensorFlow 2, 3rd ed., Birmingham, UK: Packt Publishing, 2019.)
In real-world situations, we often encounter imbalanced datasets, in which examples from one or more classes are over-represented. Class imbalance arises in many domains, such as spam filtering, fraud detection, and screening for diseases.
Use Other Metrics
In the previous discussions, we worked with an imbalanced version of the Breast Cancer Wisconsin dataset in which 90 percent of the patients are healthy. On such a dataset, simply predicting the majority class (benign tumor) for every example, without using any machine learning algorithm at all, already yields 90 percent accuracy on the test dataset. Therefore, a model that achieves approximately 90 percent test accuracy has not necessarily learned anything meaningful from the dataset's features.
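For reference, here is a minimal sketch of how such an imbalanced dataset (the X_imb and y_imb arrays used in the code further below) could be constructed from the copy of the Breast Cancer Wisconsin data bundled with scikit-learn. The relabeling step is an assumption on my part, made so that class 1 denotes the malignant (minority) class, matching the convention used in the rest of this post:
>>> import numpy as np
>>> from sklearn.datasets import load_breast_cancer
>>> # scikit-learn encodes malignant as 0 and benign as 1, so flip the labels
>>> # to make class 1 the malignant class and class 0 the benign class
>>> data = load_breast_cancer()
>>> X, y = data.data, np.where(data.target == 0, 1, 0)
>>> # keep all 357 benign examples but only the first 40 malignant ones (~90/10 split)
>>> X_imb = np.vstack((X[y == 0], X[y == 1][:40]))
>>> y_imb = np.hstack((y[y == 0], y[y == 1][:40]))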
Because accuracy is so uninformative on such imbalanced datasets, we must shift our focus toward other metrics when evaluating different models. Metrics such as precision, recall, and the ROC curve become crucial in determining the most suitable model for our specific application. For example, if our primary objective is to detect the majority of patients with cancer so that they can be recommended for further screening, recall is the pivotal metric. Conversely, in spam filtering, where we want to avoid labeling legitimate emails as spam if the system is not very certain, precision is the more appropriate metric.
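As a quick illustration (my own sketch, not code from the book), the snippet below assumes a simple stratified train/test split of the imbalanced data and a plain logistic regression classifier, and reports precision, recall, and ROC AUC using scikit-learn's metrics module:
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.metrics import precision_score, recall_score, roc_auc_score
>>> X_train, X_test, y_train, y_test = train_test_split(
...     X_imb, y_imb, test_size=0.3, stratify=y_imb, random_state=1)
>>> # fit a baseline classifier and score it with imbalance-aware metrics
>>> clf = LogisticRegression(max_iter=10000).fit(X_train, y_train)
>>> y_test_pred = clf.predict(X_test)
>>> print('Precision:', precision_score(y_test, y_test_pred))
>>> print('Recall:   ', recall_score(y_test, y_test_pred))
>>> print('ROC AUC:  ', roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))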
Assign a Larger Penalty
When evaluating machine learning models, it’s important to consider class imbalance since it can affect the learning algorithm during model training. Machine learning algorithms often optimize a reward or cost function based on the training examples they see, which can lead to bias towards the majority class. Essentially, the algorithm learns a model that prioritizes predictions for the most common class to minimize cost or maximize reward during training.
When dealing with imbalanced class proportions during model fitting, one approach is to assign a larger penalty to wrong predictions on the minority class. In scikit-learn, adjusting this penalty is as simple as setting the class_weight parameter to class_weight='balanced', which is available for most classifiers.
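For instance, a minimal sketch using logistic regression (any scikit-learn classifier that exposes class_weight works the same way; the choice of classifier here is my own assumption):
>>> from sklearn.linear_model import LogisticRegression
>>> # 'balanced' reweights each class inversely proportional to its frequency,
>>> # so mistakes on the rare class 1 are penalized more heavily during fitting
>>> weighted_lr = LogisticRegression(class_weight='balanced', max_iter=10000)
>>> weighted_lr.fit(X_imb, y_imb)
Under the hood, scikit-learn computes these weights as n_samples / (n_classes * np.bincount(y)).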
Upsample the Minority Class
Common strategies for addressing imbalance include upsampling the minority class, downsampling the majority class, and generating synthetic training examples. However, there is no single best solution that works well for all types of problems. Therefore, it is advisable to experiment with different strategies for a specific issue, assess the outcomes, and select the most suitable technique based on the results.
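For the synthetic-example strategy, a common choice is SMOTE from the imbalanced-learn package (a separate install; the snippet below is a hedged sketch assuming that package is available):
>>> from imblearn.over_sampling import SMOTE
>>> # SMOTE synthesizes new minority-class examples by interpolating between
>>> # existing class 1 examples and their nearest neighbors in feature space
>>> X_smote, y_smote = SMOTE(random_state=123).fit_resample(X_imb, y_imb)
>>> print('Class counts after SMOTE:', np.bincount(y_smote))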
The scikit-learn library includes a straightforward resample function that upsamples the minority class by drawing new samples from the dataset with replacement. The code below takes the minority class (class 1) from our imbalanced dataset and repeatedly draws new samples from it until it contains the same number of examples as the majority class (class 0):
>>> from sklearn.utils import resample
>>> print('Number of class 1 examples before: ', X_imb[y_imb==1].shape[0])
Number of class 1 examples before: 40
>>> X_upsampled, y_upsampled = resample(
...     X_imb[y_imb == 1],
...     y_imb[y_imb == 1],
...     replace=True,
...     n_samples=X_imb[y_imb == 0].shape[0],
...     random_state=123)
>>> print('Number of class 1 examples after: ', X_upsampled.shape[0])
Number of class 1 examples after: 357
After resampling, we can then combine the original class 0 samples with the upsampled class 1 subset to obtain a balanced dataset:
>>> X_bal = np.vstack((X_imb[y_imb == 0], X_upsampled))
>>> y_bal = np.hstack((y_imb[y_imb == 0], y_upsampled))
Consequently, a majority vote prediction rule would only achieve 50 percent accuracy:
>>> y_pred = np.zeros(y_bal.shape[0])
>>> np.mean(y_pred == y_bal) * 100
50
We can also downsample the majority class by removing training examples from the dataset. To downsample using the resample function, we can simply swap the class 1 label with class 0 in the previous code example, and vice versa.
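Concretely, such a swap might look like the sketch below (using replace=False is my own assumption, since when shrinking the majority class we typically sample without replacement):
>>> X_downsampled, y_downsampled = resample(
...     X_imb[y_imb == 0],
...     y_imb[y_imb == 0],
...     replace=False,
...     n_samples=X_imb[y_imb == 1].shape[0],
...     random_state=123)
>>> print('Number of class 0 examples after:', X_downsampled.shape[0])
Number of class 0 examples after: 40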