
Is there any way to get the samples under each leaf of a decision tree?

Introduction

Yes, you can get the samples that end up under each leaf of a decision tree. In scikit-learn, the usual pattern is to run apply() on the training or evaluation data, which gives you the leaf node id for each sample, then group samples by that id. This is useful for model interpretation, debugging, and understanding where the tree is overfitting.

The Core Trick: Use apply()

Every sample in a scikit-learn decision tree ends at exactly one leaf. The apply() method returns, for each row, the id of that leaf node; the id is an index into the node arrays stored on tree_.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X, y)

# apply() returns, for every row of X, the id of the leaf node it ends in.
leaf_ids = tree.apply(X)
print(leaf_ids[:10])
```

At this point, leaf_ids[i] tells you which leaf sample X[i] landed in.
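A quick sanity check can confirm that apply() only ever returns leaf nodes. This sketch refits the same iris tree as above and treats a node as a leaf when children_left equals children_right (both are -1 for leaves). On the training data every leaf should appear at least once, because with the default min_samples_leaf=1 each leaf was created from at least one training row.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

leaf_ids = tree.apply(X)

# A node is a leaf when it has no children: both child arrays hold -1 for it.
is_leaf = tree.tree_.children_left == tree.tree_.children_right
leaf_nodes = np.flatnonzero(is_leaf)

# Every id returned by apply() should be one of those leaf nodes.
print("all ids are leaves:", np.isin(leaf_ids, leaf_nodes).all())
print("leaf node ids:", leaf_nodes)
print("number of leaves:", tree.get_n_leaves())
```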

Group Samples by Leaf

Once you have the leaf ids, grouping is straightforward.

```python
from collections import defaultdict

# Map each leaf id to the list of row indices that landed in it.
samples_by_leaf = defaultdict(list)

for index, leaf_id in enumerate(leaf_ids):
    samples_by_leaf[leaf_id].append(index)

for leaf_id, sample_indices in samples_by_leaf.items():
    print(f"leaf={leaf_id} samples={sample_indices[:5]}")
```

This version stores row indices. That is usually better than copying the raw samples immediately because:

  • indices are cheaper to store,
  • you can look up labels or metadata later,
  • you can reuse the same grouping for multiple arrays.
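If you prefer NumPy arrays over Python lists, the same grouping can be written with a boolean mask per distinct leaf. This is a sketch assuming the iris tree from the earlier example; it makes one pass over leaf_ids per leaf, which is fine for shallow trees but grows with the number of leaves.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
leaf_ids = tree.apply(X)

# For each distinct leaf, np.flatnonzero turns the boolean mask into row indices.
samples_by_leaf = {
    int(leaf_id): np.flatnonzero(leaf_ids == leaf_id)
    for leaf_id in np.unique(leaf_ids)
}

for leaf_id, sample_indices in samples_by_leaf.items():
    print(f"leaf={leaf_id} count={sample_indices.size} first={sample_indices[:5]}")
```

Because the values are index arrays, they can be reused directly to slice X, y, or any other array aligned with the same rows.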

Retrieve the Actual Samples and Labels

If you want the underlying rows:

```python
for leaf_id, sample_indices in samples_by_leaf.items():
    # NumPy fancy indexing pulls out this leaf's rows and labels.
    leaf_X = X[sample_indices]
    leaf_y = y[sample_indices]
    print(f"leaf={leaf_id} count={len(sample_indices)}")
    print("first sample:", leaf_X[0])
    print("labels:", leaf_y[:5])
```

This is helpful when you want to inspect what kinds of examples the tree clustered into the same terminal decision.
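A compact way to summarize what each leaf contains is to count the labels inside it and compare that with what the leaf predicts. This sketch assumes integer labels 0..n_classes-1 (true for iris), so np.bincount can count them directly; the leaf's prediction should be one of its most frequent training labels.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
leaf_ids = tree.apply(X)
n_classes = len(tree.classes_)

label_counts = {}
leaf_prediction = {}
for leaf_id in np.unique(leaf_ids):
    rows = np.flatnonzero(leaf_ids == leaf_id)
    # Number of training samples of each class that ended in this leaf.
    counts = np.bincount(y[rows], minlength=n_classes)
    # All rows in a leaf get the same prediction, so one row is enough.
    predicted = tree.predict(X[rows[:1]])[0]
    label_counts[int(leaf_id)] = counts
    leaf_prediction[int(leaf_id)] = predicted
    print(f"leaf={leaf_id} class_counts={counts.tolist()} predicts={predicted}")
```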

Leaf Statistics from the Tree Itself

The fitted estimator also exposes its structure through the tree_ attribute. For example, tree_.n_node_samples holds the number of training samples that reached each node, internal nodes included.

```python
tree_struct = tree.tree_

for node_id in range(tree_struct.node_count):
    left = tree_struct.children_left[node_id]
    right = tree_struct.children_right[node_id]
    # Leaves have no children: both arrays hold -1 for them.
    is_leaf = left == right

    if is_leaf:
        print(
            f"leaf={node_id} "
            f"samples={tree_struct.n_node_samples[node_id]}"
        )
```

This is useful when you only want counts, not the actual sample membership.
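The counts from tree_ and the memberships from apply() should tell the same story on the training set. This sketch assumes the tree was fit on X without sample_weight; with weights, n_node_samples still counts rows, while the weighted totals live in weighted_n_node_samples.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
tree_struct = tree.tree_

# Count training rows per node id; minlength keeps positions aligned with node ids.
counts_from_apply = np.bincount(tree.apply(X), minlength=tree_struct.node_count)

leaf_mask = tree_struct.children_left == tree_struct.children_right
for node_id in np.flatnonzero(leaf_mask):
    print(
        f"leaf={node_id} "
        f"n_node_samples={tree_struct.n_node_samples[node_id]} "
        f"from_apply={counts_from_apply[node_id]}"
    )
```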

Use decision_path() for More Detail

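If you need not only the leaf but also the full path each sample took through the tree, use decision_path(). It returns a sparse CSR indicator matrix of shape (n_samples, n_nodes) in which entry (i, j) is nonzero when sample i passes through node j.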

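```python
path_matrix = tree.decision_path(X[:3])

for i in range(3):
    # In CSR format, row i's nonzero column ids (the visited nodes) are
    # stored in indices[indptr[i]:indptr[i + 1]].
    node_ids = path_matrix.indices[
        path_matrix.indptr[i]:path_matrix.indptr[i + 1]
    ]
    print(f"sample {i} visited nodes {node_ids}")
```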

This helps answer a richer question: not just “which leaf did the sample reach,” but “which sequence of decisions got it there.”
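The node ids on a path become readable once each internal node is paired with its split. This sketch uses tree_.feature and tree_.threshold, follows scikit-learn's convention that a sample goes left when its value is <= the threshold, and casts the row to float32 because the tree compares features in float32 internally. Row 100 of iris is an arbitrary choice.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X, y = data.data, data.target
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
tree_struct = tree.tree_

sample = X[100:101]  # keep it 2-D: decision_path() expects a matrix of rows
path = tree.decision_path(sample)
node_ids = path.indices[path.indptr[0]:path.indptr[1]]
leaf_id = tree.apply(sample)[0]

# Mirror the float32 conversion the tree applies to inputs.
values = sample.astype(np.float32)[0]

rules = []
for node_id in node_ids:
    if node_id == leaf_id:
        continue  # leaves have no split test
    feature = tree_struct.feature[node_id]
    threshold = tree_struct.threshold[node_id]
    # Left child when value <= threshold, right child otherwise.
    op = "<=" if values[feature] <= threshold else ">"
    rules.append(f"{data.feature_names[feature]} {op} {threshold:.2f}")

print(f"sample 100 reaches leaf {leaf_id} via:")
for rule in rules:
    print("  " + rule)
```

Joining these per-node conditions with "and" gives the rule that defines the cohort of samples in that leaf.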

Training Data Versus New Data

Be explicit about which dataset you are analyzing.

  • Using the training set shows how the fitted tree partitioned the data it learned from.
  • Using validation or test data shows how new samples distribute across leaves.

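That distinction matters. A leaf that holds a large share of the training samples but almost none of the validation samples deserves a closer look: it may point to a branch that is too specialized to the training data, or to a difference between the two datasets' distributions.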
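Comparing the two distributions is easiest with per-leaf counts for each dataset. This sketch assumes a 70/30 stratified split of iris made with train_test_split; because the datasets differ in size, it prints each leaf's share of its dataset next to the raw count.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

node_count = tree.tree_.node_count
# Rows per node id for each dataset, aligned on the same node ids.
train_counts = np.bincount(tree.apply(X_train), minlength=node_count)
val_counts = np.bincount(tree.apply(X_val), minlength=node_count)

leaves = np.flatnonzero(tree.tree_.children_left == tree.tree_.children_right)
for leaf in leaves:
    train_share = train_counts[leaf] / len(X_train)
    val_share = val_counts[leaf] / len(X_val)
    print(f"leaf={leaf} train={train_counts[leaf]} ({train_share:.0%}) "
          f"val={val_counts[leaf]} ({val_share:.0%})")
```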

Practical Interpretation Uses

Leaf-level sample inspection is useful for:

  • diagnosing overfitting,
  • explaining model behavior to stakeholders,
  • finding mislabeled or unusual samples,
  • extracting rule-based cohorts from a tree.
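For the "mislabeled or unusual samples" use, one simple heuristic is to list, per leaf, the rows whose true label differs from the leaf's prediction. This is a sketch on the iris tree from earlier; disagreeing with the leaf majority does not prove a label is wrong, it only marks rows worth reviewing.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

leaf_ids = tree.apply(X)
predicted = tree.predict(X)

# Every row in a leaf receives that leaf's prediction, so a mismatch means
# the row belongs to a minority class within its leaf.
disagree = np.flatnonzero(predicted != y)
for leaf_id in np.unique(leaf_ids[disagree]):
    rows = disagree[leaf_ids[disagree] == leaf_id]
    print(f"leaf={leaf_id} minority rows={rows.tolist()} labels={y[rows].tolist()}")
```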

For example, if one leaf contains only two training examples and predicts a class with high confidence, that is often a sign the tree is memorizing rather than generalizing.
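A sketch of that check: fit a tree with no max_depth, so it keeps splitting on the training data, then list leaves with fewer training samples than a cutoff, along with the confidence the tree assigns to them. MIN_LEAF_SIZE is a hypothetical threshold chosen for illustration, not a scikit-learn setting.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# No max_depth: the tree splits until leaves are pure or cannot be split further.
full_tree = DecisionTreeClassifier(random_state=42).fit(X, y)
tree_struct = full_tree.tree_

MIN_LEAF_SIZE = 5  # hypothetical cutoff for a "tiny" leaf; tune for your data
leaf_ids = full_tree.apply(X)

small_leaves = [
    leaf for leaf in np.unique(leaf_ids)
    if tree_struct.n_node_samples[leaf] < MIN_LEAF_SIZE
]
for leaf in small_leaves:
    rows = np.flatnonzero(leaf_ids == leaf)
    # Highest class probability the tree outputs for samples in this leaf.
    confidence = full_tree.predict_proba(X[rows[:1]]).max()
    print(f"leaf={leaf} n={rows.size} confidence={confidence:.2f} rows={rows.tolist()}")
```

Raising min_samples_leaf or limiting max_depth when refitting is the usual way to remove such leaves if they turn out to be noise.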

Common Pitfalls

  • Looking only at leaf sample counts and forgetting to recover the actual row indices when interpretation needs specific examples.
  • Assuming the leaf ids are meaningful labels rather than internal node ids assigned by the tree structure.
  • Mixing training and test samples in one grouping without tracking which dataset each sample came from.
  • Ignoring decision_path() when the real question is about rule traversal rather than final leaf membership.
  • Treating tiny leaf sample counts as normal when they may be evidence of overfitting.
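Because leaf ids are node ids, they skip the numbers used by internal nodes (node 0 is the root, not a leaf, whenever the tree has at least one split). When you want a dense 0..n_leaves-1 index, for example to size an array or build one-hot leaf features, a sketch using np.searchsorted over the sorted leaf node ids works:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
leaf_ids = tree.apply(X)

# Sorted ids of all leaf nodes (nodes with no children).
leaf_nodes = np.flatnonzero(tree.tree_.children_left == tree.tree_.children_right)

# Position of each raw leaf id within leaf_nodes gives a dense leaf index.
leaf_index = np.searchsorted(leaf_nodes, leaf_ids)
print("raw ids:  ", np.unique(leaf_ids))
print("dense ids:", np.unique(leaf_index))
```

Building the mapping from the tree structure, rather than from np.unique(leaf_ids), keeps it stable even for datasets that leave some leaves empty.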

Summary

  • Use tree.apply(X) to get the leaf id for each sample.
  • Group row indices by leaf id to recover the samples that end under each leaf.
  • Use tree.tree_.n_node_samples when you only need counts.
  • Use decision_path() when you need the full route through the tree, not just the terminal node.
  • Inspect leaf membership on both training and validation data for better model interpretation.
