How to generate a train-test-split based on a group id?

Introduction

If multiple rows belong to the same real-world entity, a random train-test split can leak information from training into testing. Splitting by group id fixes that by keeping every row from one group entirely in either the training set or the test set. This matters for patients, users, devices, sessions, stores, and any other case where rows inside a group are correlated.

Why Group-Based Splitting Matters

Assume one customer appears in ten rows. If five rows land in training and five in test, the model can partially memorize that customer’s pattern and look better than it really is. The measured accuracy is then optimistic because the test data is not truly independent.

A group-based split enforces a stronger rule:

  • each group id appears in only one split
  • evaluation is closer to predicting on unseen groups
  • leakage through repeated entities is reduced

This is often the correct setup for production use cases where the model will face entirely new users, patients, or devices.
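To make the leakage concrete, here is a minimal sketch using only the standard library. The data (10 customers with 10 rows each) and the 50/50 row-level split are assumptions chosen for illustration; the point is only to count how many customers end up on both sides when rows are shuffled without regard to group.

```python
import random

# Hypothetical data: 10 customers with 10 rows each (100 rows total).
group_ids = [customer for customer in range(10) for _ in range(10)]

# Row-level random split: shuffle row positions and cut in half,
# ignoring which customer each row belongs to.
rng = random.Random(0)
row_indices = list(range(len(group_ids)))
rng.shuffle(row_indices)
train_rows, test_rows = row_indices[:50], row_indices[50:]

# Collect the customers seen on each side and intersect them.
train_groups = {group_ids[i] for i in train_rows}
test_groups = {group_ids[i] for i in test_rows}
leaked = train_groups & test_groups

print(f"{len(leaked)} of 10 customers appear in both train and test")
```

With groups of this size, a row-level shuffle almost always spreads every customer across both sides, which is exactly the situation a group-based split is meant to rule out.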

Use GroupShuffleSplit

scikit-learn provides GroupShuffleSplit for a one-time randomized split that respects groups.

python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Eight rows of toy data belonging to three groups.
X = np.array([
    [1.0], [2.0], [3.0],
    [10.0], [11.0],
    [20.0], [21.0], [22.0]
])

y = np.array([0, 0, 1, 1, 1, 0, 0, 1])
groups = np.array([1, 1, 1, 2, 2, 3, 3, 3])

# One randomized split; whole groups go to either train or test.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_idx, test_idx = next(splitter.split(X, y, groups=groups))

print("train groups:", np.unique(groups[train_idx]))
print("test groups:", np.unique(groups[test_idx]))

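The key argument is groups=groups. GroupShuffleSplit requires it and raises an error if it is missing, because without group labels it cannot keep groups intact. Note that test_size is the proportion of groups placed in the test set, not the proportion of rows, so the share of rows in the test set can differ from 33% when group sizes are uneven.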
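Conceptually, a group-aware split shuffles the distinct group ids, sets some of them aside for testing, and then assigns each row according to its group. The sketch below illustrates that idea with the standard library only. The function name group_train_test_split, its parameters, and its rounding rule are hypothetical and chosen for illustration; this is not scikit-learn's implementation and will not reproduce its exact splits.

```python
import random

def group_train_test_split(group_ids, test_fraction=0.33, seed=42):
    """Hypothetical helper: assign whole groups to train or test.

    Returns (train_idx, test_idx) as lists of row positions.
    """
    unique_groups = sorted(set(group_ids))
    rng = random.Random(seed)
    rng.shuffle(unique_groups)

    # The fraction applies to the number of groups, not the number of rows.
    # Simplified rounding: always hold out at least one group.
    n_test_groups = max(1, round(test_fraction * len(unique_groups)))
    test_groups = set(unique_groups[:n_test_groups])

    train_idx = [i for i, g in enumerate(group_ids) if g not in test_groups]
    test_idx = [i for i, g in enumerate(group_ids) if g in test_groups]
    return train_idx, test_idx

groups = [1, 1, 1, 2, 2, 3, 3, 3]
train_idx, test_idx = group_train_test_split(groups)
print("train rows:", train_idx)
print("test rows:", test_idx)
```

Because rows are assigned by looking up their group, no group can be divided between the two sides, whichever groups the shuffle happens to pick.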

Verify the Split

Always verify that no group appears on both sides.

python
train_groups = set(groups[train_idx])
test_groups = set(groups[test_idx])

print("overlap:", train_groups & test_groups)

The overlap should be an empty set. That small check catches many accidental mistakes in preprocessing code.
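In a pipeline, the same check can fail loudly instead of printing. The sketch below wraps it in a small function; the name assert_no_group_overlap is hypothetical, and the two index lists are hand-written splits used only to show both outcomes.

```python
def assert_no_group_overlap(groups, train_idx, test_idx):
    """Hypothetical helper: raise if any group id is on both sides of a split."""
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    overlap = train_groups & test_groups
    if overlap:
        raise ValueError(f"groups in both train and test: {sorted(overlap)}")

groups = [1, 1, 1, 2, 2, 3, 3, 3]

# A clean split: group 2 is held out entirely, so nothing is raised.
assert_no_group_overlap(groups, train_idx=[0, 1, 2, 5, 6, 7], test_idx=[3, 4])

# A leaky split: group 1 has rows on both sides, so the check raises.
try:
    assert_no_group_overlap(groups, train_idx=[0, 1, 5, 6, 7], test_idx=[2, 3, 4])
except ValueError as err:
    print("caught:", err)
```

Calling a check like this right after the split, before any model is fitted, keeps a preprocessing mistake from silently inflating the reported score.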

When You Need Cross-Validation

If you need repeated evaluation rather than one train-test split, use GroupKFold or a related grouped cross-validation strategy.

python
from sklearn.model_selection import GroupKFold

# Reuses np, X, y, and groups from the GroupShuffleSplit example above.
cv = GroupKFold(n_splits=3)
for fold, (train_idx, test_idx) in enumerate(cv.split(X, y, groups=groups), start=1):
    print(f"fold {fold}")
    print(" train groups:", np.unique(groups[train_idx]))
    print(" test groups:", np.unique(groups[test_idx]))

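This keeps each group whole within each fold, and every group lands in the test set of exactly one fold. The number of folds cannot exceed the number of distinct groups.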
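Grouped folds can also be passed straight to scikit-learn's scoring helpers. The following is a sketch under stated assumptions: the data is synthetic (12 groups of 5 rows with random features and alternating labels), and LogisticRegression is just a stand-in model, so the scores themselves carry no meaning. It shows only how cross_val_score accepts a groups argument alongside a GroupKFold splitter.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold, cross_val_score

# Synthetic data (an assumption for illustration): 12 groups of 5 rows each.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(60, 3))
y_demo = np.arange(60) % 2                 # alternating labels, so every fold sees both classes
groups_demo = np.repeat(np.arange(12), 5)  # group id for each row

scores = cross_val_score(
    LogisticRegression(),
    X_demo,
    y_demo,
    groups=groups_demo,                    # passed on to the splitter's split() call
    cv=GroupKFold(n_splits=4),
)
print("per-fold accuracy:", scores)
```

Each score is then computed on groups the model never saw during that fold's training.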

Think About Group Size and Class Balance

Group-aware splitting protects against leakage, but it can create new issues if groups are very uneven. One large group can dominate the test set, and class balance can drift if labels are concentrated within certain groups.

That means you should inspect:

  • number of groups in each split
  • number of rows in each split
  • label distribution in each split
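The three checks above can be sketched with the standard library. The labels and group ids match the toy data used earlier, but the split itself (group 2 held out) is a hand-picked assumption for illustration, not the output of any particular splitter or seed. The helper name summarize is likewise hypothetical.

```python
from collections import Counter

y = [0, 0, 1, 1, 1, 0, 0, 1]
groups = [1, 1, 1, 2, 2, 3, 3, 3]

# Hypothetical split for illustration: group 2 is the test set.
splits = {"train": [0, 1, 2, 5, 6, 7], "test": [3, 4]}

def summarize(idx):
    """Count groups, rows, and labels for one side of a split."""
    return {
        "n_groups": len({groups[i] for i in idx}),
        "n_rows": len(idx),
        "labels": dict(Counter(y[i] for i in idx)),
    }

for name, idx in splits.items():
    print(name, summarize(idx))
```

In this split the test set contains only label 1, while the training set has four rows of label 0 and two of label 1. The split is leak-free, yet the label distribution has drifted badly between the two sides.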

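If label balance also matters, a stratified grouped splitter such as StratifiedGroupKFold (added in scikit-learn 1.0) may be more appropriate than a plain grouped split. It keeps groups intact while trying to keep class proportions similar across folds.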

A Pandas-Friendly Pattern

If your data starts in a DataFrame, the usage is the same.

python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

df = pd.DataFrame({
    "feature": [1, 2, 3, 10, 11, 20, 21, 22],
    "target": [0, 0, 1, 1, 1, 0, 0, 1],
    "group_id": [1, 1, 1, 2, 2, 3, 3, 3],
})

splitter = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_idx, test_idx = next(splitter.split(df[["feature"]], df["target"], groups=df["group_id"]))

# The returned indices are positional, so select rows with iloc.
train_df = df.iloc[train_idx]
test_df = df.iloc[test_idx]

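The splitter returns positional row indices, so select rows with iloc rather than loc. That stays correct even when the DataFrame index is not a default 0 to n-1 range, for example after filtering or sorting.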

Common Pitfalls

  • Using train_test_split directly on grouped data and leaking information across splits.
  • Forgetting to pass the groups array into the splitter.
  • Assuming grouped splitting also preserves label balance automatically.
  • Ignoring highly uneven group sizes when choosing test_size or the number of folds.
  • Failing to verify that the train and test group sets are disjoint.

Summary

  • Split by group id when rows from the same entity are correlated.
  • GroupShuffleSplit is the usual tool for a one-time train-test split.
  • GroupKFold is better when you need grouped cross-validation.
  • Always check that train and test groups do not overlap.
  • Review group size and label balance, because leakage prevention does not guarantee a balanced split.
