
How to do transfer learning for MNIST dataset?

Introduction

Transfer learning on MNIST is a little unusual because its 28 x 28 grayscale digits are simple enough that training from scratch often already works well. Even so, transfer learning is a useful exercise in feature reuse. In particular, it teaches you how to adapt a pretrained color-image model to grayscale digits by resizing the images and replacing the classifier head.

Understand What You Are Transferring

Transfer learning usually means taking a model pretrained on a larger dataset, keeping its learned feature-extraction layers, and training a new set of final layers for the new task.

For MNIST, that often means:

  • loading a pretrained backbone such as MobileNetV2 or ResNet
  • adapting MNIST images to the expected input shape
  • replacing the final classification head
  • optionally fine-tuning some deeper layers later

The dataset is simple, but the workflow is still representative of real transfer learning practice.

Prepare MNIST for a Pretrained Vision Backbone

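Most pretrained image models expect three-channel images that are larger than 28 x 28, with pixel values scaled the same way as their original training data. So you usually resize MNIST, repeat the grayscale channel into three channels, and apply the backbone's own preprocessing. For example, MobileNetV2's ImageNet weights are available for square inputs of 96, 128, 160, 192, or 224 pixels, and they expect pixel values in [-1, 1].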

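```python
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Add a channel axis, (N, 28, 28) -> (N, 28, 28, 1); keep raw 0-255 pixel values
x_train = x_train[..., None]
x_test = x_test[..., None]

def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [96, 96])   # 28 x 28 -> 96 x 96
    image = tf.image.grayscale_to_rgb(image)   # 1 channel -> 3 identical channels
    # MobileNetV2 was trained on pixels in [-1, 1]; this maps 0-255 into that range
    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
    return image, label

AUTOTUNE = tf.data.AUTOTUNE
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(10_000).map(preprocess, num_parallel_calls=AUTOTUNE)
            .batch(64).prefetch(AUTOTUNE))
test_ds = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
           .map(preprocess, num_parallel_calls=AUTOTUNE)
           .batch(64).prefetch(AUTOTUNE))
```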

These preprocessing steps (upsampling, channel repetition, and pixel scaling) are what make MNIST compatible with an ImageNet-style backbone.
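To make the shape and value changes concrete, here is a small NumPy sketch of the same arithmetic on a hypothetical 2 x 2 image. It assumes that MobileNetV2's preprocessing is the plain x / 127.5 - 1 scaling, which is what Keras applies for this model. Resizing is left out because it does not affect the channel or value logic.

```python
import numpy as np

# Hypothetical 2 x 2 grayscale image with MNIST-style raw pixel values (0-255)
image = np.array([[0.0, 255.0],
                  [127.5, 51.0]])

# Add a channel axis: (H, W) -> (H, W, 1)
image = image[..., None]

# Copy the single channel into three identical channels,
# which matches what tf.image.grayscale_to_rgb produces
rgb = np.repeat(image, 3, axis=-1)

# Map [0, 255] to [-1, 1], the scaling MobileNetV2's preprocess_input applies
scaled = rgb / 127.5 - 1.0

print(rgb.shape)  # (2, 2, 3)
print(scaled[..., 0])
```

A black MNIST background pixel (0) becomes -1 and a fully white stroke pixel (255) becomes 1. If you only divide by 255, the backbone sees values in [0, 1], a range it was never trained on.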

Build the Transfer Learning Model

Use a pretrained base model without its original classifier, freeze it initially, and add a small task-specific head.

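```python
import tensorflow as tf

# Pretrained MobileNetV2 without its 1000-class ImageNet classifier
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False  # freeze every backbone weight for the first stage

inputs = tf.keras.Input(shape=(96, 96, 3))
# training=False keeps BatchNormalization layers in inference mode,
# which also matters later when the backbone is unfrozen
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)               # feature maps -> one vector per image
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)  # new head for digits 0-9

model = tf.keras.Model(inputs, outputs)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",  # labels are integers 0-9, not one-hot
    metrics=["accuracy"]
)

model.fit(train_ds, epochs=3, validation_data=test_ds)
```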

That is the standard first stage: reuse the pretrained feature extractor and train only the new classifier head.
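Before training, it is worth confirming that the freeze actually took effect. The sketch below rebuilds the same architecture with weights=None. This is only so the sketch runs without downloading the ImageNet weights; keep "imagenet" for real training. It then counts trainable and frozen parameters. The n_params helper is a hypothetical name used for illustration.

```python
import numpy as np
import tensorflow as tf

# weights=None only so this sketch runs offline; real training uses weights="imagenet"
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3), include_top=False, weights=None)
base_model.trainable = False

inputs = tf.keras.Input(shape=(96, 96, 3))
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

def n_params(weights):
    # Hypothetical helper: total number of scalar values across a list of weight variables
    return sum(int(np.prod(tuple(w.shape))) for w in weights)

trainable = n_params(model.trainable_weights)
frozen = n_params(model.non_trainable_weights)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")
```

Only the Dense head's kernel and bias should be trainable. MobileNetV2's pooled output has 1280 features, so the head has 1280 x 10 + 10 = 12,810 parameters. Everything in the backbone should be counted as frozen.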

Fine-Tune Carefully If Needed

If you want to push performance or experiment with deeper adaptation, unfreeze some or all of the backbone after the new head has stabilized.

```python
# Unfreeze the backbone; recompile so the change in trainable weights takes effect
base_model.trainable = True

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),  # far below Adam's default of 1e-3
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.fit(train_ds, epochs=2, validation_data=test_ds)
```

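Use a much smaller learning rate during fine-tuning (here 1e-5 instead of Adam's default 1e-3). Large updates to a freshly unfrozen backbone can overwrite the pretrained features before the model has a chance to adapt them. Also recompile the model after changing the trainable flag; otherwise the change does not take effect.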
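If unfreezing the whole backbone is more than you need, a common variant is to unfreeze only the top layers. This sketch assumes the same MobileNetV2 setup as above. FINE_TUNE_AT = 100 is a hypothetical cut-off, not a recommended value, and weights=None is used only so the code runs offline.

```python
import tensorflow as tf

# weights=None only so this sketch runs offline; real training uses weights="imagenet"
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3), include_top=False, weights=None)

inputs = tf.keras.Input(shape=(96, 96, 3))
x = base_model(inputs, training=False)  # BatchNormalization stays in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

FINE_TUNE_AT = 100  # hypothetical cut-off: layers below this index stay frozen

# Unfreeze the backbone, then re-freeze its lower layers
base_model.trainable = True
for layer in base_model.layers[:FINE_TUNE_AT]:
    layer.trainable = False

# Recompile so the new trainable flags take effect, with a small learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

n_trainable = sum(layer.trainable for layer in base_model.layers)
print(f"{n_trainable} of {len(base_model.layers)} backbone layers are trainable")
```

Lower layers tend to hold generic edge and texture detectors, so keeping them frozen limits how much of the pretrained representation can drift during fine-tuning.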

Know When Transfer Learning Is Overkill

MNIST is so easy that a small custom CNN trained from scratch commonly reaches around 99% test accuracy. Transfer learning on MNIST is therefore more about learning the technique than about achieving a uniquely strong result.

Still, the exercise is valuable because it teaches the exact mechanics you will use later on more complex datasets where transfer learning matters much more.
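For comparison, a from-scratch baseline can be as small as the following sketch, a two-block CNN similar to the Keras MNIST convnet example. The layer sizes are illustrative assumptions, not tuned values. It works on the native 28 x 28 x 1 inputs, so no resizing or channel repetition is needed.

```python
import tensorflow as tf

# Small CNN trained from scratch on native MNIST inputs (no pretrained weights)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Rescaling(1.0 / 255),            # raw 0-255 pixels -> [0, 1]
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

print(model.count_params())  # 34826

# Training, using the raw arrays from tf.keras.datasets.mnist.load_data():
# model.fit(x_train[..., None], y_train, batch_size=128, epochs=5, validation_split=0.1)
```

This model has about 35 thousand parameters, compared with more than two million in the frozen MobileNetV2 backbone. That gap is why transfer learning rarely pays off on MNIST itself.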

Common Pitfalls

  • Forgetting to resize MNIST images, expand them to three channels, and scale pixels to the range the pretrained backbone expects (for MobileNetV2, [-1, 1]).
  • Fine-tuning immediately instead of first training a small replacement head.
  • Using too large a learning rate after unfreezing the pretrained base, or forgetting to recompile after changing the trainable flag.
  • Expecting transfer learning on MNIST to beat a simple custom CNN by a wide margin.
  • Treating the workflow as identical to training from scratch instead of respecting the separate frozen and fine-tuning stages.
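A cheap guard against the first pitfall is to inspect one preprocessed batch before calling fit. check_batch below is a hypothetical helper. Its defaults assume the 96 x 96 MobileNetV2 setup used in this article, and its last check assumes MNIST images, whose black background maps to -1 after correct preprocessing.

```python
import numpy as np

def check_batch(images, expected_hw=(96, 96), value_range=(-1.0, 1.0), expect_negative=True):
    """Hypothetical helper: fail fast if a batch does not match what the backbone expects."""
    images = np.asarray(images)
    if images.ndim != 4 or images.shape[1:3] != tuple(expected_hw):
        raise ValueError(f"expected (batch, {expected_hw[0]}, {expected_hw[1]}, 3), got {images.shape}")
    if images.shape[-1] != 3:
        raise ValueError(f"expected 3 channels, got {images.shape[-1]}")
    lo, hi = value_range
    if images.min() < lo or images.max() > hi:
        raise ValueError(f"values outside [{lo}, {hi}]: min={images.min()}, max={images.max()}")
    # MNIST backgrounds are 0, which MobileNetV2 preprocessing maps to -1;
    # no negative values usually means the images were only divided by 255
    if expect_negative and images.min() >= 0:
        raise ValueError("no negative values; was the backbone's preprocess_input applied?")
    return True
```

With the tf.data pipeline from earlier, you might call it as check_batch(next(iter(train_ds))[0].numpy()).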

Summary

  • Transfer learning on MNIST works by adapting the images to a pretrained model's expected input shape and pixel range, then replacing its classifier head.
  • Freeze the pretrained backbone first and train a new classifier head.
  • Fine-tune later with a smaller learning rate if needed.
  • MNIST is simple, so transfer learning is often more educational than necessary.
  • The workflow is still useful practice for harder image datasets where transfer learning shines.
