Custom Neural Network Implementation on MNIST using TensorFlow 2.0

Introduction

MNIST is small enough that you can learn TensorFlow 2 by building the training loop yourself instead of hiding everything behind model.fit. That makes it a good dataset for understanding the mechanics of forward propagation, loss computation, gradients, and parameter updates.

Load and prepare the dataset

MNIST images are 28 x 28 grayscale digits. A simple dense network works fine if you flatten each image into a vector and scale pixel values into the 0 to 1 range.

```python
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255.0

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(128)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)
```

Using tf.data keeps the input pipeline explicit and efficient. Even when the model is custom, you should still lean on TensorFlow's dataset utilities instead of writing manual batching loops in Python.
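To see what the pipeline actually yields, the sketch below builds the same shuffle-then-batch chain on a synthetic stand-in for MNIST (random arrays with the same 784-feature layout, an assumption so the example runs offline) and records the size of each batch. It also appends `.prefetch(tf.data.AUTOTUNE)`, which is not in the original pipeline; it is an optional addition that lets TensorFlow prepare the next batch while the current one is being processed.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST: 1,000 flattened "images" and integer labels.
x = np.random.rand(1000, 784).astype("float32")
y = np.random.randint(0, 10, size=1000).astype("int64")

ds = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(1000)               # buffer covers the whole toy dataset
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)  # optional: overlap input prep with training
)

# Seven full batches of 128, then a final partial batch of 1000 - 7 * 128 = 104.
batch_sizes = [int(images.shape[0]) for images, _ in ds]
print(batch_sizes)
```

The final batch is smaller than the rest unless you pass `drop_remainder=True` to `batch`; the same happens with the real MNIST training set.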

Define the network with raw variables

If the goal is a truly custom implementation, define weights and biases yourself rather than relying on Dense layers. That makes the math visible.

```python
class MLP(tf.Module):
    def __init__(self, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        # Hidden layer: (784, 128) weight matrix and one bias per hidden unit
        self.w1 = tf.Variable(tf.random.normal([input_dim, hidden_dim], stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([hidden_dim]))
        # Output layer: (128, 10) weight matrix and one bias per class
        self.w2 = tf.Variable(tf.random.normal([hidden_dim, num_classes], stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([num_classes]))

    def __call__(self, x):
        # x: (batch, 784) -> hidden: (batch, 128)
        hidden = tf.nn.relu(tf.matmul(x, self.w1) + self.b1)
        # Return raw logits; the loss function applies softmax internally
        logits = tf.matmul(hidden, self.w2) + self.b2
        return logits


model = MLP()
```

This is only a two-layer multilayer perceptron (784 inputs, one hidden layer of 128 ReLU units, 10 output logits), but it is enough to classify MNIST well and simple enough to understand end to end.
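Written out as equations, the forward pass in `MLP.__call__` for a batch of $B$ flattened images is shown below, where the bias vectors are broadcast across the batch rows. Multiplying out the variable shapes gives the model's parameter count:

$$
H = \operatorname{ReLU}(X W_1 + b_1), \qquad Z = H W_2 + b_2
$$

$$
X \in \mathbb{R}^{B \times 784},\quad W_1 \in \mathbb{R}^{784 \times 128},\quad b_1 \in \mathbb{R}^{128},\quad W_2 \in \mathbb{R}^{128 \times 10},\quad b_2 \in \mathbb{R}^{10}
$$

$$
784 \cdot 128 + 128 + 128 \cdot 10 + 10 = 100352 + 128 + 1280 + 10 = 101770 \text{ trainable parameters}
$$

Here $Z \in \mathbb{R}^{B \times 10}$ holds the logits that the loss consumes.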

Write the training step with GradientTape

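A GradientTape records the operations executed inside its context that involve trainable variables (or tensors you explicitly watch). After the forward pass, you ask the tape for the gradients of the loss with respect to those variables and hand them to an optimizer, which updates the parameters.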

```python
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = model(images)
        loss = loss_fn(labels, logits)

    variables = [model.w1, model.b1, model.w2, model.b2]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    train_accuracy.update_state(labels, logits)
    return loss


@tf.function
def test_step(images, labels):
    logits = model(images)
    test_accuracy.update_state(labels, logits)
```

That is the critical difference between a custom loop and high-level Keras training: you decide exactly what counts as the loss, which variables receive gradients, and when optimizer updates happen.
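To make "you decide when updates happen" concrete, the sketch below replaces the optimizer with a hand-written gradient-descent step on a hypothetical one-parameter model, loss $(w x - y)^2$. The model, the values of `x` and `y`, and the learning rate are illustrative assumptions; the point is that `tape.gradient` returns an ordinary tensor and the update `w <- w - lr * grad` is just `assign_sub`, which is what an SGD optimizer does for you inside `apply_gradients`.

```python
import tensorflow as tf

# Hypothetical one-parameter model: loss = (w * x - y) ** 2
w = tf.Variable(2.0)
x, y = tf.constant(3.0), tf.constant(12.0)
learning_rate = 0.01  # assumed step size

with tf.GradientTape() as tape:
    # The tape watches trainable variables like w automatically.
    loss = (w * x - y) ** 2  # (2 * 3 - 12) ** 2 = 36

# d(loss)/dw = 2 * (w * x - y) * x = 2 * (6 - 12) * 3 = -36
grad = tape.gradient(loss, w)

# Plain gradient descent written by hand: w <- w - lr * grad = 2.0 + 0.36
w.assign_sub(learning_rate * grad)
print(float(loss), float(grad), float(w.numpy()))
```

Swapping this manual rule for a different one (momentum, per-variable learning rates, skipping some variables) is exactly the kind of change that is awkward inside `model.fit` and trivial in a custom loop.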

Run epochs and monitor results

A simple training loop looks like this:

```python
# Average the per-batch losses over each epoch instead of keeping only the last batch's loss.
train_loss = tf.keras.metrics.Mean()

for epoch in range(3):
    train_loss.reset_state()
    train_accuracy.reset_state()
    test_accuracy.reset_state()

    for images, labels in train_ds:
        train_loss.update_state(train_step(images, labels))

    for images, labels in test_ds:
        test_step(images, labels)

    print(
        f"epoch={epoch + 1} "
        f"loss={train_loss.result().numpy():.4f} "
        f"train_acc={train_accuracy.result().numpy():.4f} "
        f"test_acc={test_accuracy.result().numpy():.4f}"
    )
```

On MNIST, this small model should reach a reasonable baseline quickly. The point is not to beat state-of-the-art accuracy. The point is to understand the mechanics well enough that more complex models later make sense.
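The `train_acc` and `test_acc` values printed by the loop come from `SparseCategoricalAccuracy`, which takes the argmax of each row of logits and compares it with the integer label. The sketch below computes the same quantity by hand on a small made-up batch (four samples, three classes, values chosen for illustration) and checks it against the metric.

```python
import tensorflow as tf

# Made-up logits for 4 samples and 3 classes, plus their integer labels.
logits = tf.constant([[2.0, 0.1, -1.0],
                      [0.2, 0.3, 3.0],
                      [1.5, 0.4, 0.0],
                      [0.0, 2.5, 0.1]])
labels = tf.constant([0, 2, 1, 1])

# Predicted class = index of the largest logit: [0, 2, 0, 1]
predictions = tf.argmax(logits, axis=1, output_type=tf.int32)
# 3 of 4 predictions match the labels
manual_acc = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))

metric = tf.keras.metrics.SparseCategoricalAccuracy()
metric.update_state(labels, logits)
print(float(manual_acc), float(metric.result()))
```

Because argmax is unchanged by softmax, accuracy can be computed directly from logits; only the loss needs to know whether its input is logits or probabilities.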

When a custom implementation is worth it

For ordinary supervised classification, Keras layers and model.fit are faster to write and easier to maintain. A custom implementation becomes useful when you need one of these:

  • multiple optimizers or losses
  • gradient clipping or inspection
  • nonstandard update rules
  • research code where you want to see every tensor transformation
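Gradient clipping is a good example of something that slots naturally into a custom step. The sketch below clips with `tf.clip_by_global_norm` between `tape.gradient` and the update, and returns the pre- and post-clipping norms so they can be inspected. To keep it self-contained, it assumes a hypothetical single-layer linear classifier on random data in place of the MLP and MNIST, a clip threshold of 1.0, and a hand-written SGD update instead of the Adam optimizer used above; in the article's `train_step`, the clipping line would go between `tape.gradient` and `optimizer.apply_gradients`.

```python
import tensorflow as tf

# Hypothetical tiny linear classifier on random data, standing in for the MLP and MNIST.
tf.random.set_seed(0)
images = tf.random.uniform([16, 784])
labels = tf.random.uniform([16], maxval=10, dtype=tf.int32)
w = tf.Variable(tf.random.normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]))
variables = [w, b]

CLIP_NORM = 1.0      # assumed threshold; tune per model
LEARNING_RATE = 0.1  # assumed step size for the hand-written SGD update


@tf.function
def clipped_train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = tf.matmul(images, w) + b
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        )
    grads = tape.gradient(loss, variables)
    # Rescale all gradients together so their joint L2 norm is at most CLIP_NORM.
    clipped, norm_before = tf.clip_by_global_norm(grads, CLIP_NORM)
    for var, grad in zip(variables, clipped):
        var.assign_sub(LEARNING_RATE * grad)
    return loss, norm_before, tf.linalg.global_norm(clipped)


loss, norm_before, norm_after = clipped_train_step(images, labels)
print(float(loss), float(norm_before), float(norm_after))
```

Clipping by the global norm keeps the direction of the combined gradient and only shrinks its length; clipping each tensor separately (for example with `tf.clip_by_norm`) is an alternative that can change the direction.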

MNIST is ideal because training is fast and debugging is manageable. If the loop has a bug, you find out quickly.

Common Pitfalls

The most common mistake is forgetting to flatten the images before using matrix multiplication in a dense network. Raw MNIST samples arrive as 28 x 28, not as vectors.
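The shape mismatch is easy to reproduce. The sketch below uses a hypothetical all-zeros batch in the raw `(batch, 28, 28)` layout and a zero weight matrix with the same `(784, 128)` shape as `w1` in the MLP, shows that multiplying the unflattened batch fails, and shows that reshaping to `(batch, 784)` fixes it.

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch in the raw MNIST layout: (batch, 28, 28) uint8 pixels.
raw_batch = np.zeros((32, 28, 28), dtype=np.uint8)
w1 = tf.zeros([28 * 28, 128])  # same shape as the MLP's first weight matrix

images = tf.cast(raw_batch, tf.float32) / 255.0

# Without flattening, the inner dimensions (28 vs. 784) do not match.
try:
    tf.matmul(images, w1)
    unflattened_ok = True
except (tf.errors.OpError, ValueError):
    unflattened_ok = False

# With flattening: (32, 784) @ (784, 128) -> (32, 128)
flat = tf.reshape(images, [-1, 28 * 28])
hidden = tf.matmul(flat, w1)
print(unflattened_ok, hidden.shape)
```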

Another mistake is applying softmax inside the model and also setting from_logits=True in the loss, which effectively applies softmax twice. Either return raw logits and let the loss apply softmax internally (from_logits=True), or apply softmax yourself and leave from_logits=False, which is the default.
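The difference is easy to measure. The sketch below computes `SparseCategoricalCrossentropy` on made-up logits for two samples (values chosen for illustration) in three ways: logits with `from_logits=True`, softmax probabilities with `from_logits=False`, and the mistaken combination of probabilities with `from_logits=True`. The first two agree; the third squashes the predictions toward a uniform distribution and reports a noticeably larger loss, which also weakens the gradients.

```python
import tensorflow as tf

# Made-up logits for 2 samples over 10 classes, with their integer labels.
labels = tf.constant([3, 7])
logits = tf.constant([[0.1, 0.2, 0.3, 2.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0],
                      [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, -1.0]])
probs = tf.nn.softmax(logits)

# Correct pairing 1: raw logits, loss applies softmax internally.
loss_a = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)
# Correct pairing 2: probabilities, loss expects them (from_logits=False is the default).
loss_b = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(labels, probs)
# Mistake: probabilities fed to a loss that applies softmax again.
loss_bad = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, probs)

print(float(loss_a), float(loss_b), float(loss_bad))
```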

People also forget to normalize pixel values. Training on raw 0 to 255 integers usually works worse and converges less cleanly.

Finally, do not mix Python-side loops and TensorFlow tensors carelessly. Keep the heavy computation in TensorFlow operations and use tf.function for the hot path once the code is correct.
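One concrete reason for care is that ordinary Python code inside a `tf.function` runs only while TensorFlow traces the function, not on every call, and a call with a new input shape triggers a new trace. The sketch below uses a hypothetical `scaled_sum` function and a Python list as a trace counter to show both effects. The same thing happens in the training loop above: 60,000 images in batches of 128 leave a final batch of 96, so `train_step` is traced a second time for that shape.

```python
import tensorflow as tf

traces = []  # Python list, mutated only as a trace-time side effect


@tf.function
def scaled_sum(x):
    # This Python statement runs while TensorFlow traces the function,
    # not each time the compiled graph is executed.
    traces.append(x.shape)
    return tf.reduce_sum(x * 2.0)


for _ in range(5):
    scaled_sum(tf.ones([4]))  # same shape and dtype: traced once, then reused

scaled_sum(tf.ones([3]))      # new input shape: triggers a second trace
print(len(traces), traces)
```

Use `tf.print` rather than `print` when you want output on every call, and keep per-batch Python work (such as calling `.numpy()`) out of the hot path.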

Summary

  • MNIST is a good dataset for learning a manual TensorFlow 2 training loop.
  • A custom implementation can be built directly with tf.Variable, tf.Module, and GradientTape.
  • Return logits, compute loss, take gradients, and apply optimizer updates explicitly.
  • tf.data handles batching and shuffling even when the model is custom.
  • Use a manual loop when you need visibility or control beyond what model.fit provides.
