How to accumulate gradients for large batch sizes in Keras

Introduction

Gradient accumulation lets you simulate a larger batch size by summing gradients across several smaller mini-batches before applying an optimizer step. This is useful when the effective batch size you want does not fit into GPU memory in one forward and backward pass.

Why Gradient Accumulation Works

Suppose you want an effective batch size of 256, but memory only allows batches of 32. You can:

  • run 8 mini-batches of size 32
  • accumulate their gradients
  • apply one optimizer update after those 8 steps

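In the simplest case (a mean-reduced loss, equal-sized mini-batches, and no batch-dependent layers such as batch normalization), the averaged accumulated gradient equals the gradient of a single 256-example batch up to floating-point rounding. Peak activation memory stays close to that of a 32-example batch; the main extra cost is one gradient buffer the size of the trainable weights.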

The key implementation detail is to divide each mini-batch loss (or its gradients) by the number of accumulation steps, so the accumulated sum equals the average gradient over the effective batch instead of being accumulation_steps times larger.
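The scaling argument can be checked numerically. The following framework-free sketch assumes a toy one-feature linear model with a mean squared error loss (the model, the synthetic data, and the `mse_grad` helper are illustrative, not part of Keras). It compares the gradient of one 256-example batch with the scaled sum of eight 32-example mini-batch gradients.

```python
import random


def mse_grad(w, b, xs, ys):
    """Gradient of mean((w * x + b - y) ** 2) with respect to w and b."""
    n = len(xs)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return grad_w, grad_b


# Synthetic data for the toy model y ≈ 3x + 1 (illustrative only).
random.seed(0)
xs = [random.gauss(0, 1) for _ in range(256)]
ys = [3 * x + 1 + random.gauss(0, 0.1) for x in xs]
w, b = 0.5, -0.2

# Gradient of one large batch of 256 examples.
full_w, full_b = mse_grad(w, b, xs, ys)

# Eight mini-batches of 32; each contribution is divided by the
# number of accumulation steps before being added to the running sum.
accumulation_steps, micro_batch = 8, 32
acc_w = acc_b = 0.0
for step in range(accumulation_steps):
    start = step * micro_batch
    chunk_x = xs[start:start + micro_batch]
    chunk_y = ys[start:start + micro_batch]
    g_w, g_b = mse_grad(w, b, chunk_x, chunk_y)
    acc_w += g_w / accumulation_steps
    acc_b += g_b / accumulation_steps

print(abs(acc_w - full_w) < 1e-9, abs(acc_b - full_b) < 1e-9)  # True True
```

The match relies on every mini-batch having the same size; with unequal chunks, each mini-batch gradient would need to be weighted by its share of the examples instead of by 1 / accumulation_steps.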

A Custom Keras Model With Accumulation

In tf.keras (the Keras 2 API used in the code below), the cleanest approach is to override train_step.

```python
import tensorflow as tf


class AccumModel(tf.keras.Model):
    def __init__(self, accumulation_steps, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accumulation_steps = accumulation_steps
        self.step_counter = tf.Variable(0, trainable=False, dtype=tf.int64)
        # A model constructed with inputs=/outputs= already owns its weights
        # at this point and build() is not called again, so create one
        # zero-initialized accumulator per trainable variable here.
        self.gradient_accumulators = [
            tf.Variable(tf.zeros_like(var), trainable=False)
            for var in self.trainable_variables
        ]

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Divide by accumulation_steps so the summed gradients equal the
            # average gradient over the effective batch.
            loss = self.compiled_loss(
                y, y_pred, regularization_losses=self.losses
            ) / self.accumulation_steps

        gradients = tape.gradient(loss, self.trainable_variables)

        # Add this mini-batch's scaled gradients to the running sums.
        for acc, grad in zip(self.gradient_accumulators, gradients):
            acc.assign_add(grad)

        self.step_counter.assign_add(1)

        # Inside fit(), AutoGraph converts this tensor-valued `if` to tf.cond.
        if tf.equal(self.step_counter % self.accumulation_steps, 0):
            self.optimizer.apply_gradients(
                zip(self.gradient_accumulators, self.trainable_variables)
            )
            # Zero the buffers so the next cycle starts fresh.
            for acc in self.gradient_accumulators:
                acc.assign(tf.zeros_like(acc))

        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```

This model divides the loss by accumulation_steps, accumulates gradients, and applies them only after the chosen number of mini-batches.

A Minimal Usage Example

```python
inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(1)(x)

model = AccumModel(accumulation_steps=4, inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x_train = tf.random.normal((128, 10))
y_train = tf.random.normal((128, 1))

model.fit(x_train, y_train, batch_size=8, epochs=2)
```

Here the mini-batch size is 8 and four mini-batch gradients are accumulated before each optimizer update, so the effective batch size is 8 × 4 = 32.
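The arithmetic behind those numbers can be written as a small helper. `accumulation_schedule` is a hypothetical function, and it assumes Keras-style batching, where the final mini-batch of an epoch may be smaller, and a step counter that starts at zero and is never reset between epochs, as in `AccumModel`.

```python
import math


def accumulation_schedule(num_samples, batch_size, accumulation_steps):
    """Hypothetical helper returning (effective batch size, mini-batches per
    epoch, optimizer updates in the first epoch, mini-batches whose gradients
    are still in the accumulators when that epoch ends)."""
    mini_batches = math.ceil(num_samples / batch_size)  # last one may be smaller
    effective_batch = batch_size * accumulation_steps
    updates, leftover = divmod(mini_batches, accumulation_steps)
    return effective_batch, mini_batches, updates, leftover


print(accumulation_schedule(128, 8, 4))  # (32, 16, 4, 0)
print(accumulation_schedule(100, 8, 4))  # (32, 13, 3, 1)
```

With 100 samples, the 13th mini-batch holds only 4 examples, and because the counter and accumulators persist, its gradients are folded into the first update of the next epoch.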

Things That Still Change

Gradient accumulation is not always identical to true large-batch training. Differences can appear when your model uses:

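  • batch normalization (statistics are computed over each mini-batch, not over the effective batch)
  • gradient clipping applied per mini-batch rather than once to the accumulated gradient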
  • adaptive optimizer internals
  • data augmentation randomness per mini-batch

So the technique is very useful, but treat it as a close approximation of true large-batch training rather than an exact equivalent in every detail.
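Two of the listed differences can be reproduced without a framework. This sketch uses plain Python stand-ins (the `normalize` and `clip_by_norm` helpers, the random activations, and the gradient values are illustrative assumptions): batch-norm style standardization over 256 values versus eight chunks of 32, and norm clipping applied once to the averaged gradient versus to each mini-batch gradient.

```python
import math
import random


def normalize(values, eps=1e-3):
    """Batch-norm style standardization using only the statistics of `values`."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def clip_by_norm(vec, max_norm):
    """Rescale `vec` so its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(g * g for g in vec))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in vec]


# Batch normalization: statistics over 256 activations vs. 8 chunks of 32.
random.seed(1)
activations = [random.gauss(2, 3) for _ in range(256)]
large_batch = normalize(activations)
per_mini_batch = []
for start in range(0, 256, 32):
    per_mini_batch.extend(normalize(activations[start:start + 32]))
bn_gap = max(abs(a - b) for a, b in zip(large_batch, per_mini_batch))
print(f"largest batch-norm difference: {bn_gap:.3f}")

# Gradient clipping (max norm 1.0) on two illustrative mini-batch gradients.
mini_grads = [[4.0, 0.0], [0.0, 4.0]]
averaged = [sum(col) / len(mini_grads) for col in zip(*mini_grads)]
clip_after = clip_by_norm(averaged, 1.0)  # clip once, as with a true large batch
clipped = [clip_by_norm(g, 1.0) for g in mini_grads]
clip_before = [sum(col) / len(clipped) for col in zip(*clipped)]  # clip per mini-batch
print([round(g, 4) for g in clip_after], clip_before)  # [0.7071, 0.7071] [0.5, 0.5]
```

Clipping configured on the optimizer should act on the accumulated gradient that `AccumModel` passes to `apply_gradients`, which corresponds to the "clip once" case; the mismatch appears when clipping happens per mini-batch before accumulation.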

When to Use It

Use gradient accumulation when:

  • GPU memory is the limiting factor
  • you want a larger effective batch, for example for less noisy gradient estimates or to match the batch size a training recipe was tuned for
  • changing model size is less desirable than changing update frequency

It is especially common in large language models, high-resolution vision training, and any workload where activation memory dominates.
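When memory is the constraint, the number of accumulation steps follows from the target batch size and the largest mini-batch that fits. `plan_accumulation` is a hypothetical helper, and it assumes the mini-batch size should divide the target exactly so every mini-batch carries equal weight in the accumulated gradient.

```python
def plan_accumulation(target_batch, max_micro_batch):
    """Hypothetical helper: smallest accumulation_steps whose mini-batch fits.

    Only mini-batch sizes that divide target_batch exactly are considered.
    Returns (accumulation_steps, mini_batch_size).
    """
    for steps in range(1, target_batch + 1):
        micro = target_batch // steps
        if target_batch % steps == 0 and micro <= max_micro_batch:
            return steps, micro
    raise ValueError("max_micro_batch must be at least 1")


print(plan_accumulation(256, 32))   # (8, 32)
print(plan_accumulation(256, 48))   # (8, 32)
print(plan_accumulation(256, 256))  # (1, 256)
```

Requiring an exact divisor can leave some memory unused (a limit of 48 still yields mini-batches of 32), which is the trade-off for keeping the loss scaling simple.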

Common Pitfalls

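The biggest pitfall is forgetting to scale the loss or gradients. If you sum raw mini-batch gradients without dividing by the number of mini-batches, the accumulated gradient is accumulation_steps times the average gradient; with plain SGD that is equivalent to multiplying the learning rate by the same factor.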
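A minimal numeric sketch of that pitfall, using made-up per-mini-batch gradients for a single parameter and a plain SGD step of `learning_rate * gradient`. It also shows that dividing each loss and dividing the summed gradient once at apply time produce the same update.

```python
# Illustrative per-mini-batch gradients for one parameter; each value is
# already the mean over its own mini-batch.
mini_batch_grads = [0.5, 1.5, 1.0, 1.0]
accumulation_steps = len(mini_batch_grads)
learning_rate = 0.5

raw_sum = sum(mini_batch_grads)  # no scaling: 4x the average gradient
scaled_loss = sum(g / accumulation_steps for g in mini_batch_grads)  # scale each loss
scaled_at_apply = raw_sum / accumulation_steps  # or scale the sum once at apply time

# Plain SGD step size for each variant.
print(learning_rate * raw_sum, learning_rate * scaled_loss, learning_rate * scaled_at_apply)
# 2.0 0.5 0.5
```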

Another common mistake is thinking gradient accumulation fixes all large-batch concerns. It solves memory pressure, but it does not erase every optimization or normalization difference between small and truly large batches.

Developers also sometimes forget to zero the accumulation buffers after applying gradients, which causes updates to include stale gradients from previous cycles.
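The effect of a missing reset can be sketched with one scalar parameter. `run_cycles` is a hypothetical stand-in for the accumulation logic in `AccumModel`: it scales each gradient by the number of accumulation steps and records the value handed to the optimizer at the end of every cycle.

```python
def run_cycles(cycles, accumulation_steps, reset):
    """Return the accumulated gradient applied at the end of each cycle."""
    buffer = 0.0
    applied = []
    for cycle in cycles:
        for grad in cycle:
            buffer += grad / accumulation_steps  # scaled, as in AccumModel
        applied.append(buffer)
        if reset:
            buffer = 0.0  # zero the accumulator after the optimizer step


    return applied


cycles = [[1.0, 3.0], [2.0, 2.0]]  # two cycles of two mini-batch gradients
print(run_cycles(cycles, 2, reset=True))   # [2.0, 2.0]
print(run_cycles(cycles, 2, reset=False))  # [2.0, 4.0]
```

Without the reset, the second update carries the first cycle's gradient as well, and the error compounds with every further cycle.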

Summary

  • Gradient accumulation simulates a larger effective batch size using multiple smaller mini-batches.
  • In Keras, overriding train_step is a clean way to implement it.
  • Divide the loss consistently so the accumulated update scale matches the intended batch behavior.
  • The method reduces memory pressure but is not always perfectly identical to true large-batch training.
  • Reset the accumulated gradients after each optimizer step so stale gradients from the previous cycle do not leak into the next update.

All Rights Reserved.