How to Extract Classes from a Prefetched Dataset in TensorFlow for a Confusion Matrix

Introduction

Prefetching in TensorFlow improves pipeline performance, but it does not change how labels are stored or extracted. If you want a confusion matrix, the usual workflow is still the same: iterate over the dataset, collect true labels, run the model to collect predicted labels, and then feed both arrays into tf.math.confusion_matrix or another metric tool.

Prefetch Does Not Hide the Labels

A prefetched dataset is still a tf.data.Dataset. The prefetch transformation only prepares upcoming elements while the model works on the current one, which keeps the pipeline efficient. It does not remove or alter the class labels. Any change to them comes from earlier steps, such as a map that you applied yourself.

For example:

```python
import tensorflow as tf

features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
labels = tf.constant([0, 1, 1, 0])

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.batch(2).prefetch(tf.data.AUTOTUNE)
```

Each element of this dataset is still a pair:

  • a batch of features
  • a batch of labels

The fact that prefetching is enabled does not change how you unpack those values.
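
To confirm this structure yourself, you can inspect the dataset's element_spec or pull out a single batch. The following minimal sketch rebuilds the toy dataset from above. The prints are only for inspection.

```python
import tensorflow as tf

features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
labels = tf.constant([0, 1, 1, 0])
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

# element_spec describes every element's structure without iterating.
print(dataset.element_spec)

# Pull one batch to confirm it unpacks into (features, labels).
x_batch, y_batch = next(iter(dataset))
print(x_batch.shape, y_batch.shape)  # (2, 1) (2,)
```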

Extract True Classes by Iterating the Dataset

The simplest way to gather labels for a confusion matrix is to iterate over the dataset and append the label batches.

```python
import tensorflow as tf

true_batches = []

for x_batch, y_batch in dataset:
    true_batches.append(y_batch)

y_true = tf.concat(true_batches, axis=0)
print(y_true)
```
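
If you only need the labels, another option is to drop the features with map, flatten the batches with unbatch, and read the result out with as_numpy_iterator. This sketch rebuilds the toy dataset from above. Because it is a separate, label-only pass, it is only safe when the dataset yields the same order on every iteration (for example, when no shuffle reshuffles it between passes). Otherwise these labels will not line up with predictions made in a different pass.

```python
import numpy as np
import tensorflow as tf

features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
labels = tf.constant([0, 1, 1, 0])
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

# Keep only the label half of each (features, labels) pair,
# then undo the batching so each element is a single class ID.
label_ds = dataset.map(lambda x, y: y).unbatch()

# as_numpy_iterator yields NumPy scalars; collect them into one array.
y_true = np.array(list(label_ds.as_numpy_iterator()))
print(y_true)  # [0 1 1 0]
```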

If your labels are one-hot encoded instead of integer class IDs, convert them first:

```python
# y_true_one_hot has shape (num_samples, num_classes)
y_true = tf.argmax(y_true_one_hot, axis=1)
```

This is a common source of errors: tf.math.confusion_matrix expects 1-D vectors of class IDs, not one-hot vectors.
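
Here is a small worked example of the conversion. The one-hot labels and predicted IDs are made-up values for three classes, not output from a real model. In TensorFlow's confusion matrix, rows are true classes and columns are predicted classes. Passing num_classes fixes the size at 3x3.

```python
import tensorflow as tf

# Illustrative one-hot labels for 4 samples and 3 classes.
y_true_one_hot = tf.constant([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 1, 0],
])
y_pred = tf.constant([0, 2, 2, 1])  # illustrative predicted class IDs

# Collapse each one-hot row to the index of its 1: [0, 1, 2, 1].
y_true = tf.argmax(y_true_one_hot, axis=1)

cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=3)
print(cm.numpy())
# [[1 0 0]
#  [0 1 1]
#  [0 0 1]]
```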

Collect Predictions in the Same Pass

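For evaluation, you normally collect labels and predictions together in a single pass. If the dataset uses shuffle, which by default reshuffles on every iteration, then one pass for labels and a second pass for predictions (for example, model.predict) would see the examples in different orders. The confusion matrix would then compare mismatched pairs: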

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(2)
])

true_batches = []
pred_batches = []

for x_batch, y_batch in dataset:
    logits = model(x_batch, training=False)
    predictions = tf.argmax(logits, axis=1)

    true_batches.append(y_batch)
    pred_batches.append(predictions)

y_true = tf.concat(true_batches, axis=0)
y_pred = tf.concat(pred_batches, axis=0)

cm = tf.math.confusion_matrix(y_true, y_pred)
print(cm)
```

This works whether the dataset is prefetched or not. Prefetching only affects performance, not the extraction pattern.

Whether the model outputs logits or softmax probabilities, tf.argmax along the class axis gives the predicted class IDs. Softmax does not change which score is largest, so both give the same result.
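
The loop above can be wrapped in a small reusable helper. In this sketch the name collect_true_and_pred is hypothetical, and the model is the same untrained two-class toy, so the counts themselves are meaningless. The sketch also passes num_classes to tf.math.confusion_matrix. That keeps the matrix at a fixed size even when some class never appears, whereas without it TensorFlow infers the size from the largest class ID it sees.

```python
import tensorflow as tf


def collect_true_and_pred(model, dataset):
    """Hypothetical helper: one pass over (features, labels) batches that
    returns aligned 1-D tensors of true and predicted class IDs."""
    true_batches, pred_batches = [], []
    for x_batch, y_batch in dataset:
        outputs = model(x_batch, training=False)
        # argmax over the class axis works for logits or softmax probabilities.
        pred_batches.append(tf.argmax(outputs, axis=-1))
        true_batches.append(tf.cast(y_batch, tf.int64))
    return tf.concat(true_batches, axis=0), tf.concat(pred_batches, axis=0)


features = tf.constant([[1.0], [2.0], [3.0], [4.0], [5.0]])
labels = tf.constant([0, 1, 1, 0, 1])
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .shuffle(5)  # reshuffles on every iteration; one pass keeps pairs aligned
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

# Untrained two-class toy model, as in the loop above.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(2),
])

y_true, y_pred = collect_true_and_pred(model, dataset)

# num_classes fixes the matrix at 2x2 even if a class is never predicted.
cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=2)
print(cm)
```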

Keep Label Encoding Consistent

The confusion matrix is only meaningful if y_true and y_pred use the same class encoding. Common possibilities are:

  • integer labels such as 0, 1, 2
  • one-hot labels that need argmax
  • string labels that need mapping to IDs first

If your dataset yields string labels, or labels nested inside a dictionary, map them to integer class IDs before computing the matrix.
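
For string labels, one option is a lookup table applied inside the pipeline with map. This sketch uses tf.lookup.StaticHashTable. The class names and their order are illustrative assumptions. For the matrix to be meaningful, the model's predicted IDs must follow the same order.

```python
import tensorflow as tf

# Illustrative class names; this order defines class IDs 0, 1, 2.
class_names = tf.constant(["cat", "dog", "bird"])
class_ids = tf.constant([0, 1, 2], dtype=tf.int64)

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(class_names, class_ids),
    default_value=-1,  # unknown strings map to -1 so they are easy to spot
)

features = tf.constant([[0.1], [0.2], [0.3]])
string_labels = tf.constant(["dog", "cat", "bird"])

dataset = (
    tf.data.Dataset.from_tensor_slices((features, string_labels))
    # Replace each string label with its integer class ID.
    .map(lambda x, y: (x, table.lookup(y)))
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

y_true = tf.concat([y for _, y in dataset], axis=0)
print(y_true.numpy())  # [1 0 2]
```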

For example, with one-hot labels:

```python
for x_batch, y_batch in dataset:
    # y_batch has shape (batch_size, num_classes) when labels are one-hot
    y_ids = tf.argmax(y_batch, axis=1)
```

For binary models that output a single sigmoid value instead of multi-class logits, use thresholding instead:

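```python
# model ends in a single sigmoid unit, so outputs have shape (batch_size, 1)
predictions = tf.cast(model(x_batch, training=False) > 0.5, tf.int32)
# drop the trailing size-1 axis so predictions is 1-D like the labels
predictions = tf.squeeze(predictions, axis=-1)
```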

The right extraction step depends on how the dataset labels and model outputs are encoded.
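
The binary case can be put together end to end. This sketch assumes a model ending in a single sigmoid unit. It uses an untrained toy model, so the counts are arbitrary. The loop is the same single-pass pattern as before, with thresholding in place of argmax.

```python
import tensorflow as tf

features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
labels = tf.constant([0, 1, 1, 0])
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

# Single sigmoid unit: output shape is (batch, 1), values in (0, 1).
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

true_batches, pred_batches = [], []
for x_batch, y_batch in dataset:
    probs = model(x_batch, training=False)
    # Threshold at 0.5, then drop the trailing size-1 axis -> shape (batch,).
    preds = tf.squeeze(tf.cast(probs > 0.5, tf.int32), axis=-1)
    true_batches.append(y_batch)
    pred_batches.append(preds)

y_true = tf.concat(true_batches, axis=0)
y_pred = tf.concat(pred_batches, axis=0)
cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=2)
print(cm)
```

If the last layer has no sigmoid activation and outputs raw logits, threshold at 0 instead of 0.5, since sigmoid(0) = 0.5.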

Convert to NumPy Only If You Need Another Library

If you plan to use scikit-learn’s confusion matrix instead of TensorFlow’s, convert after concatenation:

```python
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true.numpy(), y_pred.numpy())
print(cm)
```

This is usually cleaner than converting batch by batch. Gather the tensors first, then convert once.

That also keeps most of the work in TensorFlow, which is simpler if you are already using a tf.data pipeline.

Common Pitfalls

The biggest mistake is assuming prefetch changes how labels are accessed. It does not. You still iterate the dataset and unpack the elements normally.

Another issue is mixing one-hot labels with integer predictions without converting them to the same representation first.

Developers also sometimes compute predictions batch by batch but forget to concatenate the batches before building the confusion matrix. The matrix needs full aligned vectors of true and predicted classes.

Finally, watch the output shape of the model. A softmax classifier, a sigmoid classifier, and a one-hot label dataset all require slightly different extraction logic.
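
A quick shape check can catch several of these mistakes before you build the matrix. The helper check_confusion_inputs below is a hypothetical sketch, not a TensorFlow API. It rejects inputs that are not 1-D, such as unconverted one-hot labels or unsqueezed (batch, 1) sigmoid outputs. It also rejects label and prediction vectors whose lengths differ.

```python
import tensorflow as tf


def check_confusion_inputs(y_true, y_pred):
    """Hypothetical helper (not a TensorFlow API): raise early if the
    inputs are not aligned 1-D vectors of class IDs."""
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    if y_true.shape.rank != 1 or y_pred.shape.rank != 1:
        raise ValueError(
            f"expected 1-D class IDs, got shapes {y_true.shape} and "
            f"{y_pred.shape}; apply tf.argmax or tf.squeeze first")
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"{y_true.shape[0]} labels but {y_pred.shape[0]} predictions; "
            "concatenate all batches from the same pass")


labels = tf.constant([0, 1, 1, 0])
preds = tf.constant([0, 1, 0, 0])
check_confusion_inputs(labels, preds)  # passes: both 1-D, same length

try:
    # One-hot labels that were never converted with tf.argmax.
    check_confusion_inputs(tf.one_hot(labels, depth=2), preds)
except ValueError as err:
    print(err)
```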

Summary

  • Prefetching affects performance, not the basic way labels are extracted from a dataset.
  • Iterate the dataset, collect label batches, and concatenate them.
  • Convert one-hot labels or probability outputs into class IDs before building the confusion matrix.
  • Use tf.math.confusion_matrix or convert once to NumPy for scikit-learn if needed.
  • The important part is keeping true labels and predicted labels in the same class representation.

