Difference between train_test_split and StratifiedShuffleSplit

Introduction

Splitting data into training and testing sets is one of the first steps in any machine learning workflow. Scikit-learn provides train_test_split for quick, simple splits and StratifiedShuffleSplit for splits that preserve class distributions. Choosing the wrong one can produce misleading evaluation metrics, especially when your dataset has imbalanced classes.

This article explains how each method works, demonstrates them with code, and shows exactly when class-preserving splits matter.

How train_test_split Works

train_test_split randomly shuffles the data and divides it into two subsets according to a specified ratio. It does not consider the distribution of class labels.

python
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import numpy as np

# Create a dataset with 1000 samples, 2 classes
X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.9, 0.1], random_state=42)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set class distribution: {np.bincount(y_train)}")
print(f"Test set class distribution:     {np.bincount(y_test)}")

Typical output:

 
Training set class distribution: [718  82]
Test set class distribution:     [176  24]

The split is random, so the class proportions in the training and test sets may not match the original dataset. With a 90/10 class imbalance, random chance might give you a test set with 15% minority class or only 8%. For small datasets, this variance is even larger.
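To get a feel for how much the minority share can drift, the sketch below repeats the plain random split with 50 different seeds and records the minority-class fraction of each test set. The seed range and the variable name minority_share are illustrative choices, not part of the original example, and the exact numbers depend on the scikit-learn version.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.9, 0.1], random_state=42)

# Repeat the unstratified split with different seeds and record
# the share of class 1 in each resulting test set.
minority_share = []
for seed in range(50):
    _, _, _, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
    minority_share.append(np.mean(y_test == 1))

minority_share = np.array(minority_share)
print(f"Overall minority share:  {np.mean(y == 1):.3f}")
print(f"Test-set minority share: min={minority_share.min():.3f}, "
      f"max={minority_share.max():.3f}")
```

The gap between the minimum and maximum is the seed-to-seed variation that stratification is meant to remove.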

Parameters

  • test_size: Fraction or absolute number of samples for the test set.
  • train_size: Fraction or absolute number for the training set. If omitted, it is the complement of test_size.
  • random_state: Seed for reproducibility.
  • shuffle: Whether to shuffle before splitting (default True).
  • stratify: Optional array of class labels. When set to y, the split is stratified on those labels (see below). Stratifying requires shuffle=True.
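The following sketch shows how test_size and shuffle change the result on a toy array. The data and the variable names are made up for illustration only.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)   # 10 toy samples, 2 features
y = np.array([0] * 5 + [1] * 5)

# Float test_size: a fraction of the samples (0.2 of 10 gives 2 test samples).
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
print(len(X_tr), len(X_te))

# Integer test_size: an absolute number of test samples.
X_tr2, X_te2, y_tr2, y_te2 = train_test_split(X, y, test_size=4, random_state=0)
print(len(X_tr2), len(X_te2))

# shuffle=False keeps the original order, so the last rows form the test set.
X_tr3, X_te3, y_tr3, y_te3 = train_test_split(X, y, test_size=0.2, shuffle=False)
print(y_te3)
```

With shuffle=False on data sorted by label, the test set here contains only class 1, which is one reason shuffling is the default.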

How StratifiedShuffleSplit Works

StratifiedShuffleSplit ensures that each split preserves the percentage of samples for each class. It is a cross-validation iterator, meaning it can produce multiple independent train/test splits.

python
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.datasets import make_classification
import numpy as np

X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.9, 0.1], random_state=42)

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# split() yields index arrays; with n_splits=1 the loop runs once
for train_idx, test_idx in sss.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

print(f"Training set class distribution: {np.bincount(y_train)}")
print(f"Test set class distribution:     {np.bincount(y_test)}")

Typical output:

 
Training set class distribution: [716  84]
Test set class distribution:     [178  22]

The proportions in both sets now closely match the original 90/10 ratio. This is critical for imbalanced datasets because it ensures the test set is representative of the real-world class distribution.
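As a rough check of that stability, the sketch below draws 20 stratified splits and collects the number of minority samples in each test set. The choice of 20 splits and the name minority_counts are assumptions made for this illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit

X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.9, 0.1], random_state=42)

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.2, random_state=0)

# Count class-1 samples in the test set of every split.
minority_counts = [int(np.sum(y[test_idx] == 1))
                   for _, test_idx in sss.split(X, y)]

print("Distinct minority counts across splits:", sorted(set(minority_counts)))
```

The counts should stay within one sample of each other, since the only freedom left is how fractional counts are rounded.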

Side-by-Side Comparison

Here is a direct comparison on a heavily imbalanced dataset with 5 classes.

python
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.datasets import make_classification
import numpy as np

X, y = make_classification(n_samples=500, n_classes=5,
                           n_informative=5,
                           weights=[0.5, 0.2, 0.15, 0.1, 0.05],
                           random_state=42)

# Random split
_, _, _, y_test_random = train_test_split(X, y, test_size=0.2, random_state=42)

# Stratified split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for _, test_idx in sss.split(X, y):
    y_test_stratified = y[test_idx]

print("Original distribution:   ", np.bincount(y))
print("Random split test set:   ", np.bincount(y_test_random))
print("Stratified split test set:", np.bincount(y_test_stratified))

The stratified split will mirror the original proportions much more closely, while the random split may over-represent or under-represent rare classes.
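Raw counts are hard to compare when the sets differ in size, so one option is to convert them to class shares and look at the largest gap from the original distribution. The helper class_shares below is hypothetical and written for this sketch. It passes minlength to np.bincount so the result keeps one entry per class even if a rare class is missing from a split.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

X, y = make_classification(n_samples=500, n_classes=5, n_informative=5,
                           weights=[0.5, 0.2, 0.15, 0.1, 0.05],
                           random_state=42)
n_classes = 5


def class_shares(labels, n_classes):
    """Return each class's fraction of `labels` (hypothetical helper)."""
    # minlength keeps the array length fixed even if a class is absent
    return np.bincount(labels, minlength=n_classes) / len(labels)


# Plain random split
_, _, _, y_test_random = train_test_split(X, y, test_size=0.2, random_state=42)

# Stratified split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
_, test_idx = next(sss.split(X, y))
y_test_stratified = y[test_idx]

original = class_shares(y, n_classes)
for name, labels in [("random", y_test_random),
                     ("stratified", y_test_stratified)]:
    max_gap = np.max(np.abs(class_shares(labels, n_classes) - original))
    print(f"{name:>10}: largest gap from original share = {max_gap:.3f}")
```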

Multiple Splits for Cross-Validation

A major advantage of StratifiedShuffleSplit is that it can generate multiple independent splits, which is useful for repeated evaluation.

python
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import numpy as np

# Reuses X and y from the previous example
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
scores = []

for train_idx, test_idx in sss.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    clf = RandomForestClassifier(random_state=42)
    clf.fit(X_train, y_train)
    score = accuracy_score(y_test, clf.predict(X_test))
    scores.append(score)

print(f"Mean accuracy: {np.mean(scores):.4f} (+/- {np.std(scores):.4f})")

Each of the 5 splits preserves class proportions, giving you a more reliable estimate of model performance than a single random split.
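The manual loop can also be handed to scikit-learn's cross_val_score, which accepts a splitter object through its cv argument. The sketch below is one way to do that. The dataset mirrors the five-class example, and n_estimators=50 is an arbitrary choice to keep the run short.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score

X, y = make_classification(n_samples=500, n_classes=5, n_informative=5,
                           weights=[0.5, 0.2, 0.15, 0.1, 0.05],
                           random_state=42)

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
clf = RandomForestClassifier(n_estimators=50, random_state=42)

# cross_val_score fits a fresh copy of clf on each train split
# and scores it on the matching test split.
scores = cross_val_score(clf, X, y, cv=sss, scoring="accuracy")

print(f"Mean accuracy: {np.mean(scores):.4f} (+/- {np.std(scores):.4f})")
```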

The stratify Parameter in train_test_split

If you only need a single stratified split (not multiple), you can use the stratify parameter of train_test_split instead of creating a StratifiedShuffleSplit object.

python
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

This produces the same result as a single-split StratifiedShuffleSplit. Use this shorthand when you do not need repeated splits.
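A quick way to confirm this on your own data is to run both approaches with the same seed and compare the class counts of the two test sets. The sketch below compares counts only; whether the exact same rows are selected is an implementation detail that may vary between scikit-learn versions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.9, 0.1], random_state=42)

# Shorthand: stratify through train_test_split
_, _, _, y_test_a = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Explicit splitter with a single split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
_, test_idx = next(sss.split(X, y))
y_test_b = y[test_idx]

counts_a = np.bincount(y_test_a, minlength=2)
counts_b = np.bincount(y_test_b, minlength=2)
print("train_test_split(stratify=y):", counts_a)
print("StratifiedShuffleSplit:      ", counts_b)
```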

When to Use Each

Use train_test_split (without stratify) when:

  • Your classes are roughly balanced (within a few percentage points of each other).
  • You are doing a quick experiment and do not need precise class representation.
  • Your dataset is large enough that random sampling naturally preserves proportions.

Use StratifiedShuffleSplit (or train_test_split with stratify=y) when:

  • Your dataset has imbalanced classes.
  • You are working with a small dataset where random variation in class proportions could skew results.
  • You need multiple independent train/test splits for repeated evaluation.
  • Your evaluation metric is sensitive to class distribution (precision, recall, F1).

Common Pitfalls

  1. Ignoring class imbalance. Using plain train_test_split on a small dataset with a 95/5 class split can produce a test set with very few or even zero samples from the minority class, making metrics for that class meaningless.
  2. Forgetting to pass y to split(). StratifiedShuffleSplit.split() requires both X and y. If you pass only X, the stratification has no labels to work with and raises an error.
  3. Confusing StratifiedShuffleSplit with StratifiedKFold. StratifiedKFold partitions the data into non-overlapping folds for k-fold cross-validation. StratifiedShuffleSplit creates random splits that may overlap across iterations. Use StratifiedKFold when you need every sample to appear in exactly one test fold.
  4. Using n_splits > 1 without iterating. If you set n_splits=5 but only take the first split, you are wasting the setup. Either use all splits or set n_splits=1.
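The difference described in pitfall 3 can be made visible by counting how often each sample lands in a test set. The toy data and the hit-counter arrays below are assumptions made for this sketch.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

X = np.zeros((100, 1))               # features are irrelevant here
y = np.array([0] * 80 + [1] * 20)    # 80/20 class split

# StratifiedKFold: the test folds partition the data.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
kfold_hits = np.zeros(len(y), dtype=int)
for _, test_idx in skf.split(X, y):
    kfold_hits[test_idx] += 1

# StratifiedShuffleSplit: each test set is drawn independently.
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
shuffle_hits = np.zeros(len(y), dtype=int)
for _, test_idx in sss.split(X, y):
    shuffle_hits[test_idx] += 1

print("StratifiedKFold test appearances per sample:       ",
      np.unique(kfold_hits))
print("StratifiedShuffleSplit test appearances per sample:",
      np.unique(shuffle_hits))
```

With StratifiedKFold every sample is tested exactly once. With StratifiedShuffleSplit some samples are typically tested several times and others not at all.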

Summary

train_test_split is the go-to method for quick, single random splits on balanced datasets. StratifiedShuffleSplit preserves class proportions in every split as closely as the sample counts allow, which is essential for imbalanced data and small datasets. For a single stratified split, you can use train_test_split with the stratify parameter as a convenient shorthand. For repeated evaluation with class-preserving splits, StratifiedShuffleSplit with n_splits > 1 is the right tool.

