How do I split TensorFlow datasets?

Introduction

Splitting data into training, validation, and test sets is easy in principle, but the exact technique depends on where the dataset came from. With TensorFlow, the cleanest options are either to split a tf.data.Dataset with take and skip or to use predefined split support when loading from TensorFlow Datasets.

Split a tf.data.Dataset with take and skip

If you already have a tf.data.Dataset, the most direct way to split it is to shuffle it once, then take and skip fixed counts.

```python
import tensorflow as tf

seed = 123
size = 100

raw = tf.data.Dataset.range(size)
# Shuffle once with a fixed seed and keep that order for every pass over the data.
shuffled = raw.shuffle(size, seed=seed, reshuffle_each_iteration=False)

train_size = 70
val_size = 15
test_size = size - train_size - val_size

# Consecutive, non-overlapping slices of the same shuffled order.
train_ds = shuffled.take(train_size)
val_ds = shuffled.skip(train_size).take(val_size)
test_ds = shuffled.skip(train_size + val_size).take(test_size)

print(train_ds.cardinality().numpy())  # 70
print(val_ds.cardinality().numpy())    # 15
print(test_ds.cardinality().numpy())   # 15
```

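The important detail is reshuffle_each_iteration=False. Without it, tf.data reshuffles the data every time the dataset is iterated, typically once per epoch, so take and skip would select different elements on each pass and examples would drift between the training, validation, and test partitions.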
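
A quick way to confirm that the partitions really are fixed is to pull each one into Python, check that they do not overlap, and iterate the training split a second time to make sure its order does not change. This sketch rebuilds the range-based split from above and only makes sense for small toy data, since it materializes every element.

```python
import tensorflow as tf

seed = 123
size = 100
train_size, val_size = 70, 15

shuffled = tf.data.Dataset.range(size).shuffle(
    size, seed=seed, reshuffle_each_iteration=False
)
train_ds = shuffled.take(train_size)
val_ds = shuffled.skip(train_size).take(val_size)
test_ds = shuffled.skip(train_size + val_size)

# Materialize each partition (fine for 100 integers, not for real data).
train = [int(v) for v in train_ds.as_numpy_iterator()]
val = [int(v) for v in val_ds.as_numpy_iterator()]
test = [int(v) for v in test_ds.as_numpy_iterator()]

# No example should appear in more than one partition...
assert not set(train) & set(val)
assert not set(train) & set(test)
assert not set(val) & set(test)
# ...together the partitions should cover every example exactly once...
assert sorted(train + val + test) == list(range(size))
# ...and a second pass (a new "epoch") should give the same training order.
assert train == [int(v) for v in train_ds.as_numpy_iterator()]
```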

Use Cardinality Carefully

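This technique assumes you know the dataset size. For finite datasets created from arrays or ranges, cardinality() returns the exact count. After transformations such as filter, and for many file-based sources, it returns tf.data.UNKNOWN_CARDINALITY (-2), and for a dataset repeated indefinitely it returns tf.data.INFINITE_CARDINALITY (-1). In those cases you have to count the elements with a full pass or obtain the size from the source in some other way.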
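
Before relying on the count, it helps to check what cardinality() actually reports. The helper count_elements below is a hypothetical sketch that assumes the dataset is finite and cheap enough to iterate once; it falls back to a full pass with reduce when the cardinality is unknown, as it is after filter.

```python
import tensorflow as tf

def count_elements(ds):
    """Return the number of elements in a finite tf.data.Dataset."""
    n = int(ds.cardinality().numpy())
    if n == tf.data.INFINITE_CARDINALITY:
        raise ValueError("Cannot split an infinite dataset by element count.")
    if n == tf.data.UNKNOWN_CARDINALITY:
        # Fall back to one full pass over the data, adding 1 per element.
        n = int(ds.reduce(0, lambda count, _: count + 1).numpy())
    return n

filtered = tf.data.Dataset.range(100).filter(lambda x: tf.equal(x % 3, 0))
print(int(filtered.cardinality().numpy()))  # -2, i.e. UNKNOWN_CARDINALITY
print(count_elements(filtered))             # 34
```

The full pass costs one read of the data, so it is worth doing once and storing the result rather than recounting in every run.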

When cardinality is known, percentage-based splits are easy:

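```python
count = int(raw.cardinality().numpy())
train_size = int(count * 0.8)
val_size = int(count * 0.1)
# Give the remainder to the test split so rounding never drops examples.
test_size = count - train_size - val_size
```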

You should compute these sizes once and then reuse them consistently in the pipeline.
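
Putting those pieces together, a small helper can compute the sizes once from the cardinality and reuse them for every partition. The name split_by_fraction and its default fractions are hypothetical; the sketch assumes a finite dataset with known cardinality and sends any rounding remainder to the test split.

```python
import tensorflow as tf

def split_by_fraction(ds, train_frac=0.8, val_frac=0.1, seed=123):
    """Hypothetical helper: shuffle once, then cut fixed train/val/test partitions."""
    count = int(ds.cardinality().numpy())
    if count < 0:
        raise ValueError("Dataset cardinality must be known and finite.")
    # Compute the sizes once; the test split gets whatever rounding leaves over.
    train_size = int(count * train_frac)
    val_size = int(count * val_frac)
    # A buffer as large as the dataset gives a full shuffle but holds every element in memory.
    ds = ds.shuffle(count, seed=seed, reshuffle_each_iteration=False)
    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size + val_size)
    return train_ds, val_ds, test_ds

train_ds, val_ds, test_ds = split_by_fraction(tf.data.Dataset.range(105))
print([int(d.cardinality().numpy()) for d in (train_ds, val_ds, test_ds)])  # [84, 10, 11]
```

For datasets too large to buffer in full, a smaller shuffle buffer still works with this helper, at the cost of a less thorough mix.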

TensorFlow Datasets Has Native Split Support

If you are loading data with tensorflow_datasets, the split can often be expressed directly in the loader call.

```python
import tensorflow_datasets as tfds

train_ds, val_ds, test_ds = tfds.load(
    "mnist",
    split=["train[:80%]", "train[80%:90%]", "train[90%:]"],
    as_supervised=True,
)
```

This is often cleaner than loading the whole training set and then manually slicing it yourself. It also keeps the split definition close to the dataset source, which improves readability.

If the dataset already provides official train and test splits, a common pattern is to split only the training partition further and keep the official test set untouched.

Keep Labels and Features Aligned

When the dataset is built from tensors or NumPy arrays, create the dataset from aligned tuples before splitting. That way the partitioning happens on complete examples rather than on features and labels independently.

```python
import numpy as np
import tensorflow as tf

x = np.random.randn(100, 4).astype("float32")
y = (x.sum(axis=1) > 0).astype("int32")

# Pair each feature row with its label before any shuffling or splitting.
ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.shuffle(len(x), seed=123, reshuffle_each_iteration=False)

train_ds = ds.take(80)
val_ds = ds.skip(80).take(10)
test_ds = ds.skip(90)
```

This avoids one of the worst preprocessing mistakes: splitting features and labels separately and accidentally destroying correspondence.
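
Because each element is a (features, label) pair, take and skip cannot separate the two. One way to check this after splitting is to regenerate each label from its features. This sketch assumes synthetic data whose label is a known function of the features (here, the sign of the first column), which is what makes the check possible; with real data, a stored example ID could play the same role.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 4)).astype("float32")
# Assumed labelling rule for the demo: label 1 when the first feature is positive.
y = (x[:, 0] > 0).astype("int32")

ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.shuffle(len(x), seed=123, reshuffle_each_iteration=False)
train_ds = ds.take(80)
val_ds = ds.skip(80).take(10)
test_ds = ds.skip(90)

def labels_match(split):
    """True if every label still agrees with the rule that generated it."""
    return all(
        int(label) == int(features[0] > 0)
        for features, label in split.as_numpy_iterator()
    )

print(labels_match(train_ds), labels_match(val_ds), labels_match(test_ds))
```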

Split Before Expensive Augmentation

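A useful rule is to split first and augment second. If augmentation creates several variants of each source example (for example by repeating the data or expanding it with flat_map) before the partitions are defined, variants of the same source example can end up on both sides of a boundary. That can make validation look better than it really is. Augmenting before the split also applies random transformations to the validation and test data, which should normally stay unmodified.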

Keep the raw partition stable, then apply training-only augmentation to the training dataset.
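
A minimal sketch of that ordering, assuming small synthetic images of shape 8x8x1 and a random horizontal flip as the augmentation: the partitions are cut from the un-augmented data, and only the training branch goes through map(augment). The training split is also reshuffled every epoch after the fixed split, which is safe because its membership no longer changes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy images: 100 examples of shape 8x8x1 with binary labels.
rng = np.random.default_rng(0)
images = rng.random((100, 8, 8, 1), dtype=np.float32)
labels = rng.integers(0, 2, size=100).astype("int32")

ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.shuffle(len(images), seed=123, reshuffle_each_iteration=False)

# 1. Define the raw partitions first.
train_raw = ds.take(80)
val_ds = ds.skip(80).take(10)
test_ds = ds.skip(90)

def augment(image, label):
    # Random horizontal flip; other random ops would slot in the same way.
    return tf.image.random_flip_left_right(image), label

# 2. Augment (and reshuffle every epoch) only the training partition.
train_ds = (
    train_raw.shuffle(80)  # reshuffles on each iteration by default
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
val_ds = val_ds.batch(16)
test_ds = test_ds.batch(16)
```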

Common Pitfalls

Shuffling without a fixed seed and then calling take and skip can make the split change between runs. Set a seed and disable reshuffling when defining stable partitions.

Using take and skip on an ordered dataset without shuffling can create biased splits if the source is sorted by class, time, or user.

Splitting features and labels separately breaks alignment and corrupts the dataset.
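
The sorted-source pitfall is easy to reproduce. In this sketch the labels are assumed to be sorted by class, so an unshuffled take and skip produce a test split that contains only one class.

```python
import numpy as np
import tensorflow as tf

# Hypothetical labels sorted by class: 50 examples of class 0, then 50 of class 1.
labels = np.array([0] * 50 + [1] * 50, dtype="int32")
ds = tf.data.Dataset.from_tensor_slices(labels)

def class_counts(split):
    """Count how many examples of each class a split contains."""
    values = [int(v) for v in split.as_numpy_iterator()]
    return {c: values.count(c) for c in (0, 1)}

# Without shuffling, take and skip simply cut the sorted sequence.
print(class_counts(ds.take(70)))   # {0: 50, 1: 20}
print(class_counts(ds.skip(85)))   # {0: 0, 1: 15}

# Shuffling once with a fixed seed first lets each split draw from both classes.
shuffled = ds.shuffle(len(labels), seed=123, reshuffle_each_iteration=False)
print(class_counts(shuffled.take(70)))
print(class_counts(shuffled.skip(85)))
```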

Summary

  • For tf.data.Dataset, use shuffle, then take and skip.
  • For tensorflow_datasets, prefer native split expressions when available.
  • Keep dataset order stable while defining the split.
  • Split before augmentation and always keep features and labels aligned.
