Complete Guide to Cross-Validation
This guide will explore the ins and outs of cross-validation, examine its different methods, and discuss why it matters in today's data science and machine learning processes.

Machine learning models often need lots of data, but how well they perform on new, unseen data is what really matters. Cross-validation is a way to test how well a model works by splitting the data into parts, training the model on some, and testing it on the rest. This helps spot overfitting and underfitting, giving an idea of how the model might do in real-world situations.
This guide will take you through the basics, types, and best ways to use cross-validation to make your machine learning work better.
Prerequisites
Before starting with practical cross-validation, make sure you have a good understanding of the following:
- Machine learning basics: Learn about concepts like overfitting and underfitting and how to measure how well a model works
- Python skills: Be competent with the basics of Python and using tools such as Scikit-learn, Pandas, and NumPy
- Preparing your data: Know how to split your data into training and testing sets, and the reasons why we do this
To follow along with our examples, ensure that you've got Python on your computer with these libraries:
pip install numpy pandas scikit-learn matplotlib
What is Cross-Validation?
Let's start at the beginning: Cross-validation is one of the most widely used data resampling methods to assess the generalization ability of a predictive model and to prevent overfitting. Unlike simple train-test splits, it provides a more comprehensive understanding by rotating the training and testing sets. This helps ensure that every data point has the opportunity to be tested and contributes to a reliable performance metric.
Key Points of Cross-Validation:
- Evaluate model performance consistently
- Minimize bias by testing on diverse subsets of data
- Optimize hyperparameters through repeated validation cycles
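If you just want a single number, scikit-learn's cross_val_score wraps this whole rotate-train-test cycle into one call. Here is a minimal sketch on purely illustrative random data (the data and model choice are assumptions for demonstration, not part of any real workflow):
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
import numpy as np
# Purely illustrative random data: 100 samples, 5 features, binary target
X = np.random.rand(100, 5)
y = np.random.randint(0, 2, 100)
# cv=5 requests five validation cycles; each cycle returns one score
scores = cross_val_score(LogisticRegression(), X, y, cv=5)
print("Score per fold:", scores)
print("Mean score:", scores.mean())
The manual loops in the sections below perform the same rotation explicitly, which makes the mechanics easier to see.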
Cross-validation comes in several variants, each suited to different data structures and problems. Let's check out the most commonly used techniques.
1. K-Fold Cross-Validation
Process:
- The dataset is split into k subsets (folds)
- The model is trained on k-1 folds and tested on the remaining fold
- This process repeats for each fold, ensuring that every subset is used for testing
- The final performance metric is the average of all test runs
Advantages:
- Works well for most datasets
- Reduces variance by averaging results
Considerations:
- Choosing the right value for k (commonly 5 or 10) is essential
Code Example:
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np
# Example data
X = np.random.rand(100, 5) # 100 samples, 5 features
y = np.random.randint(0, 2, 100) # Binary target variable
kf = KFold(n_splits=5)
model = LogisticRegression()
accuracies = []
for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))
print("Average Accuracy:", np.mean(accuracies))
2. Stratified K-Fold Cross-Validation
Stratified k-fold cross-validation is similar to plain vanilla k-fold, except that it ensures each fold has the same class distribution as the whole dataset. This makes it particularly well suited to imbalanced datasets.
Code Example:
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5)
accuracies = []
for train_index, test_index in skf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))
print("Stratified Average Accuracy:", np.mean(accuracies))
3. Leave-One-Out Cross-Validation
In every run, leave-one-out cross-validation (LOOCV) uses a single data point for testing and all remaining points for training. This process is extremely thorough but computationally expensive for large datasets.
Code Example:
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
accuracies = []
for train_index, test_index in loo.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))
print("Leave-One-Out Cross-Validation Accuracy:", np.mean(accuracies))
4. Time Series Cross-Validation
Time series cross-validation possesses the following characteristics:
- Specifically designed for time-dependent data
- Training is performed in earlier periods, and testing is done in later periods
- Helps maintain the temporal sequence of the data
Code Example:
from sklearn.model_selection import TimeSeriesSplit
tscv = TimeSeriesSplit(n_splits=3)
accuracies = []
for train_index, test_index in tscv.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))
print("Time Series Cross-Validation Accuracy:", np.mean(accuracies))
5. Group K-Fold Cross-Validation
Here is what sets group k-fold cross-validation apart from its cross-validation counterparts:
- Ensures that groups of data points (e.g., from the same user or batch) are either entirely in the training or testing set
- Prevents data leakage when the dataset contains such groups
Code Example:
from sklearn.model_selection import GroupKFold
# Simulating grouped data
groups = np.random.randint(0, 5, len(X)) # 5 unique groups
gkf = GroupKFold(n_splits=5)
accuracies = []
for train_index, test_index in gkf.split(X, y, groups):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))
print("Group K-Fold Cross-Validation Accuracy:", np.mean(accuracies))
Benefits of Cross-Validation
Why go through this resampling process at all when building models? Here are some of the more important reasons:
- Improved Model Reliability: It gives more robust performance measures
- Prevent Overfitting: Testing the model several times on different data splits enhances its ability to generalize
- Optimized Hyper-parameters: It helps in better tuning of hyper-parameters for optimal performance (see the sketch after this list)
- Extensive Evaluation: It ensures the use of all data points for both training and testing
- Reduces Variance: Multiple train/test splits yield more reliable performance metrics
- Applicable Across Models: Useful for simple and complicated models
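As a minimal sketch of the hyperparameter-tuning benefit, here is how cross-validation pairs with scikit-learn's GridSearchCV. The parameter grid is just an illustrative assumption, and X and y are the random arrays defined in the K-Fold example above:
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
# Illustrative grid over the regularization strength C
param_grid = {"C": [0.01, 0.1, 1, 10]}
# Every candidate value is evaluated with 5-fold cross-validation
grid = GridSearchCV(LogisticRegression(), param_grid, cv=5, scoring="accuracy")
grid.fit(X, y)
print("Best C:", grid.best_params_["C"])
print("Best cross-validated accuracy:", grid.best_score_)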
Best Practices
Here are some important best practices for using cross-validation:
- Choose the Right Cross-Validation Technique: The technique should correspond to the dataset and problem type
- Watch for Data Leakage: Test data should not leak into the training data
- Combine with Grid Search: Cross-validation can be used together with grid search to optimize hyperparameters (as sketched earlier)
- Balance Computational Cost and Thoroughness: Techniques like LOOCV are very expensive
- Use Visualizations: Plots will help in visualizing performance trends across folds
Visualization Example:
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold
X, y = make_classification(n_samples=100,
                           n_features=2,
                           n_classes=2,
                           random_state=42,
                           n_informative=2,
                           n_redundant=0)
kf = KFold(n_splits=5)
plt.figure(figsize=(10, 6))
for i, (train_index, test_index) in enumerate(kf.split(X)):
    plt.scatter(X[test_index, 0], X[test_index, 1], label=f'Fold {i + 1}')
plt.title('K-Fold Cross-Validation Splits')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.show()

The above visualization shows how the dataset is split into 5 different folds for cross-validation. Each color represents data points assigned to the test set in each fold.
Conclusion
Cross-validation is an essential tool in machine learning, ensuring that models generalize well and providing insights into model performance. With the right cross-validation techniques and best practices, you have one more tool in your kit for building robust, reliable models for real-world applications.
References & Further Reading
- Scikit-Learn Documentation: Official documentation for scikit-learn, covering cross-validation techniques and examples
- Python Data Science Handbook by Jake VanderPlas: A comprehensive guide for data science in Python, including model evaluation and validation
- Machine Learning Yearning by Andrew Ng: A free book that explains model evaluation strategies and best practices.
Shittu Olumide is a software engineer and technical writer passionate about leveraging cutting-edge technologies to craft compelling narratives, with a keen eye for detail and a knack for simplifying complex concepts. You can also find Shittu on Twitter.