Display MNIST image using matplotlib

Introduction

Displaying an MNIST digit with matplotlib is simple once the array shape is correct. MNIST images are grayscale 28 x 28 arrays, so the main job is to load one sample, make sure it is two-dimensional, and render it with a grayscale color map.
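
Those three steps can be wrapped in a small reusable function. The sketch below is one way to do it, assuming the sample is a NumPy array (or anything np.asarray accepts); the name `show_digit` is hypothetical and not part of matplotlib or Keras.

```python
import matplotlib.pyplot as plt
import numpy as np


def show_digit(image, label=None, ax=None):
    """Render one MNIST-style digit in grayscale (hypothetical helper)."""
    image = np.asarray(image)
    # Single-channel data must be 2-D for imshow to apply a colormap.
    if image.ndim != 2:
        raise ValueError(f"expected a 2-D image, got shape {image.shape}")
    if ax is None:
        ax = plt.gca()  # draw on the current axes, creating one if needed
    # Grayscale colormap: low values dark, high values bright.
    artist = ax.imshow(image, cmap="gray")
    if label is not None:
        ax.set_title(f"Label: {label}")
    ax.axis("off")  # hide ticks and frame
    return artist
```

With the Keras data loaded as in the next section, `show_digit(x_train[0], y_train[0])` followed by `plt.show()` would draw the first training digit. Raising on non-2-D input, rather than silently reshaping, keeps shape mistakes visible.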

Load MNIST

A convenient source is the Keras dataset loader.

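```python
from tensorflow.keras.datasets import mnist

# Downloads the dataset on first use and caches it locally.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape)  # (60000, 28, 28): 60,000 images of 28 x 28 pixels
print(y_train.shape)  # (60000,): one integer label (0-9) per image
```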

The built-in loader returns x_train with shape (60000, 28, 28), so each individual image x_train[i] is already a 28 x 28 array and ready for plotting.
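
Before plotting, it can help to confirm the shape, dtype and value range of what was loaded; the Keras loader returns uint8 pixels from 0 to 255. The sketch below uses only NumPy and a small synthetic array in place of `x_train`, so it runs without downloading anything. `describe_images` is a hypothetical helper name.

```python
import numpy as np


def describe_images(images):
    """Summarize an image array before plotting (hypothetical helper)."""
    images = np.asarray(images)
    return {
        "shape": images.shape,
        "dtype": str(images.dtype),
        "min": images.min().item(),  # .item() converts to a plain Python number
        "max": images.max().item(),
    }


# Synthetic stand-in shaped like a small slice of x_train: (count, 28, 28) uint8.
fake_batch = np.zeros((5, 28, 28), dtype=np.uint8)
fake_batch[:, 10:18, 12:16] = 255  # a bright vertical stroke
print(describe_images(fake_batch))
```

Called on the real x_train, the same function should report a shape of (60000, 28, 28) and dtype uint8.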

Display One Image

```python
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

(x_train, y_train), _ = mnist.load_data()

index = 0
image = x_train[index]  # shape (28, 28), dtype uint8
label = y_train[index]

plt.imshow(image, cmap="gray")  # grayscale: low values dark, high values bright
plt.title(f"Label: {label}")
plt.axis("off")  # hide ticks and axis frame
plt.show()
```

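cmap="gray" matters because MNIST is a single-channel grayscale dataset. Without it, matplotlib maps the values through its default colormap (viridis unless configured otherwise), so the digit is drawn in false colors instead of shades of gray.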
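
To see the difference directly, the sketch below draws the same synthetic square with and without cmap="gray". The left panel uses whatever rcParams["image.cmap"] is set to; the test image is made up purely for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic 28 x 28 "digit": a bright square on a dark background.
image = np.zeros((28, 28), dtype=np.uint8)
image[8:20, 8:20] = 255

fig, (ax_default, ax_gray) = plt.subplots(1, 2, figsize=(6, 3))

# No cmap argument: matplotlib falls back to rcParams["image.cmap"].
im_default = ax_default.imshow(image)
ax_default.set_title("default cmap")

# Explicit grayscale mapping.
im_gray = ax_gray.imshow(image, cmap="gray")
ax_gray.set_title('cmap="gray"')

for ax in (ax_default, ax_gray):
    ax.axis("off")
```

Calling plt.show() afterwards displays both panels side by side.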

Display Several Digits in a Grid

A grid is useful for sanity checking labels or preprocessing.

```python
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

(x_train, y_train), _ = mnist.load_data()

# 3 x 3 grid of axes; axes.flat iterates over them row by row.
fig, axes = plt.subplots(3, 3, figsize=(6, 6))

for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i], cmap="gray")
    ax.set_title(f"y={y_train[i]}")
    ax.axis("off")

plt.tight_layout()  # keep titles from overlapping neighboring images
plt.show()
```

This gives you a quick visual check that the data looks like handwritten digits rather than scrambled arrays.
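
The hard-coded 3 x 3 grid assumes exactly nine images. A sketch that sizes the grid from the number of images and blanks out unused cells might look like this; `plot_digit_grid` is a hypothetical name, and the blank synthetic digits only stand in for real MNIST samples.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_digit_grid(images, labels, cols=5):
    """Plot images in a grid sized to fit them (hypothetical helper)."""
    n = len(images)
    rows = math.ceil(n / cols)
    # squeeze=False keeps `axes` 2-D even when rows or cols is 1.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide ticks on every cell, including unused ones
        if i < n:
            ax.imshow(images[i], cmap="gray")
            ax.set_title(f"y={labels[i]}")
    fig.tight_layout()
    return fig, axes


# Synthetic data: 7 blank digits fill 2 rows of 5, leaving 3 empty cells.
demo_images = np.zeros((7, 28, 28), dtype=np.uint8)
demo_labels = np.arange(7)
fig, axes = plot_digit_grid(demo_images, demo_labels, cols=5)
```

With real data, plot_digit_grid(x_train[:20], y_train[:20], cols=5) would lay out the first 20 training digits in four rows.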

If the Data Is Flattened

Some preprocessing pipelines flatten each image into a length-784 vector. In that case, reshape before plotting.

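```python
import numpy as np
import matplotlib.pyplot as plt

# Random pixels stand in for a flattened 784-element sample (the plot shows noise, not a digit).
flat_image = np.random.randint(0, 256, size=(784,), dtype=np.uint8)
image = flat_image.reshape(28, 28)  # 784 = 28 * 28, row-major order

plt.imshow(image, cmap="gray")
plt.axis("off")
plt.show()
```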

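Passing the 784-element vector straight to imshow raises an error, because imshow expects a 2-D array for grayscale data. This missing reshape is one of the most common reasons a digit does not display correctly.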
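
When a whole batch has been flattened to shape (N, 784), one reshape restores all the images at once. This sketch assumes the pixels were flattened in row-major order, NumPy's default; data flattened column-major needs order="F" on the reshape, otherwise every digit appears transposed. `unflatten_batch` is a hypothetical helper.

```python
import numpy as np


def unflatten_batch(flat, side=28):
    """Reshape (N, side*side) rows back into (N, side, side) images."""
    flat = np.asarray(flat)
    if flat.ndim != 2 or flat.shape[1] != side * side:
        raise ValueError(f"expected shape (N, {side * side}), got {flat.shape}")
    # NumPy reshapes in row-major (C) order by default, matching how
    # image.reshape(-1) or image.ravel() flattened each image.
    return flat.reshape(-1, side, side)


original = np.arange(2 * 28 * 28).reshape(2, 28, 28)
flat = original.reshape(2, -1)  # (2, 784), as a flattening pipeline would produce
restored = unflatten_batch(flat)
print(restored.shape)  # (2, 28, 28)
```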

Useful Display Tweaks

For small images like MNIST, interpolation="nearest" can make the pixels easier to inspect.

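```python
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

(x_train, _), _ = mnist.load_data()
plt.imshow(x_train[0], cmap="gray", interpolation="nearest")  # crisp square pixels
plt.axis("off")
plt.show()
```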

This is handy when you are debugging thresholding, normalization, or image centering.
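
As an example of that kind of debugging, the sketch below places a synthetic digit next to a thresholded copy, both drawn with interpolation="nearest" so individual pixels stay visible. The fixed threshold of 128 is an arbitrary assumption standing in for whatever preprocessing step is under test.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for one digit: a soft-edged blob with values 0..255.
yy, xx = np.mgrid[0:28, 0:28]
image = (255 * np.exp(-((xx - 14) ** 2 + (yy - 14) ** 2) / 40.0)).astype(np.uint8)

# Hypothetical preprocessing step under debug: a fixed threshold at 128.
binary = (image >= 128).astype(np.uint8)

fig, (ax_raw, ax_bin) = plt.subplots(1, 2, figsize=(6, 3))
# "nearest" draws each pixel as a sharp square, so edge pixels are easy to count.
ax_raw.imshow(image, cmap="gray", interpolation="nearest")
ax_raw.set_title("raw")
ax_bin.imshow(binary, cmap="gray", interpolation="nearest")
ax_bin.set_title("threshold >= 128")
for ax in (ax_raw, ax_bin):
    ax.axis("off")
```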

A very practical debugging trick is to display the true label and a model prediction in the title after training. That helps you inspect misclassified digits quickly and notice patterns such as poor centering, heavy normalization artifacts, or classes the model often confuses. Visualization is not just for demos; it is one of the fastest ways to sanity-check a classification pipeline.
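
A minimal sketch of that workflow, assuming integer arrays of true and predicted labels are already available: it picks the indices where they disagree and titles each digit with both values. The function name `plot_misclassified` and the synthetic data are hypothetical; with a Keras classifier, predicted labels might come from model.predict(x_test).argmax(axis=1).

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_misclassified(images, y_true, y_pred, max_items=9, cols=3):
    """Show up to max_items digits whose prediction differs from the label."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    wrong = np.flatnonzero(y_true != y_pred)[:max_items]
    if len(wrong) == 0:
        return None, wrong  # nothing to show
    rows = -(-len(wrong) // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, idx in zip(axes.flat, wrong):
        ax.imshow(images[idx], cmap="gray")
        ax.set_title(f"true={y_true[idx]} pred={y_pred[idx]}")
    fig.tight_layout()
    return fig, wrong


# Synthetic example: predictions disagree with labels at indices 1, 3 and 5.
images = np.zeros((6, 28, 28), dtype=np.uint8)
y_true = np.array([0, 1, 2, 3, 4, 5])
y_pred = np.array([0, 7, 2, 8, 4, 6])
fig, wrong = plot_misclassified(images, y_true, y_pred)
```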

If your pipeline normalizes pixel values into the range from 0.0 to 1.0, matplotlib still displays them correctly as long as the array shape is right. Problems that look like rendering errors are often really preprocessing errors such as incorrect reshape order, inverted values, or accidental batching dimensions left on the sample.
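
One detail worth knowing: by default imshow stretches the color limits to the array's own minimum and maximum, so a faint image is drawn at full contrast. Passing vmin and vmax pins the scale. The sketch below compares both on a synthetic, deliberately faint digit scaled to the range 0.0 to 1.0.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic uint8 digit with a faint stroke (max value 100 instead of 255).
raw = np.zeros((28, 28), dtype=np.uint8)
raw[6:22, 13:15] = 100

# Typical preprocessing: scale 0..255 integers to 0.0..1.0 floats.
normalized = raw.astype(np.float32) / 255.0

fig, (ax_auto, ax_fixed) = plt.subplots(1, 2, figsize=(6, 3))
# Default: limits follow the data's min and max, so the faint stroke renders white.
im_auto = ax_auto.imshow(normalized, cmap="gray")
ax_auto.set_title("auto limits")
# Fixed limits keep 1.0 as white, so the stroke stays dark gray.
im_fixed = ax_fixed.imshow(normalized, cmap="gray", vmin=0.0, vmax=1.0)
ax_fixed.set_title("vmin=0, vmax=1")
for ax in (ax_auto, ax_fixed):
    ax.axis("off")
```

Fixed limits are mainly useful when comparing several images, since each would otherwise be stretched differently.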

Displaying several digits side by side is also a quick check for shuffled labels or corrupted data, and for comparing raw input images with their normalized or augmented versions during debugging.
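
A small helper for that kind of side-by-side comparison might look like the sketch below. The names are hypothetical, and np.roll is only a crude stand-in for a real augmentation: it wraps pixels around the edge instead of padding.

```python
import matplotlib.pyplot as plt
import numpy as np


def compare_versions(versions, titles):
    """Draw several versions of the same digit in one row (hypothetical helper)."""
    fig, axes = plt.subplots(1, len(versions), figsize=(2.5 * len(versions), 2.5), squeeze=False)
    for ax, img, title in zip(axes.flat, versions, titles):
        ax.imshow(img, cmap="gray", interpolation="nearest")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig


raw = np.zeros((28, 28), dtype=np.uint8)
raw[5:23, 12:16] = 255  # synthetic vertical stroke
normalized = raw / 255.0  # scaled to 0.0..1.0
shifted = np.roll(normalized, shift=3, axis=1)  # crude augmentation: 3 px right, wrapping

fig = compare_versions([raw, normalized, shifted], ["raw", "normalized", "shifted +3px"])
```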

Common Pitfalls

  • Forgetting to reshape flattened vectors back to 28 x 28.
  • Omitting cmap="gray" and getting a misleading colorized image.
  • Passing arrays with extra batch or channel dimensions when you only want to display one sample.
  • Confusing normalized pixel values with broken image rendering.
  • Forgetting plt.show() in environments where rendering is not automatic.
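
Two of these pitfalls, flattened vectors and leftover batch or channel axes, can be caught in one place before calling imshow, which rejects shapes such as (784,) or (1, 28, 28). A minimal sketch, assuming square 28 x 28 images; `to_display_image` is a hypothetical helper.

```python
import numpy as np


def to_display_image(sample, side=28):
    """Coerce one sample into a 2-D (side, side) array for imshow (hypothetical helper)."""
    # Drop size-1 axes such as a leading batch axis or a trailing channel axis.
    arr = np.squeeze(np.asarray(sample))
    # Restore a flattened vector to 2-D, assuming row-major flattening.
    if arr.shape == (side * side,):
        arr = arr.reshape(side, side)
    if arr.shape != (side, side):
        raise ValueError(f"cannot display shape {np.shape(sample)} as a {side}x{side} image")
    return arr
```

For example, to_display_image(batch[:1]) turns a one-sample slice of shape (1, 28, 28) into a plain (28, 28) array, while a genuine multi-image batch still raises instead of being plotted incorrectly.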

Summary

  • MNIST digits should be displayed as 28 x 28 grayscale images.
  • plt.imshow(image, cmap="gray") is the standard way to render one sample.
  • Reshape flattened vectors of length 784 before plotting.
  • Use subplot grids to inspect multiple digits quickly.
  • Small display options such as interpolation="nearest" can make debugging easier.