
Keras occupies an indefinitely increasing amount of memory for each epoch


Introduction

If memory usage appears to grow every epoch during Keras training, the cause is not always a true framework leak. Common causes include graph and model state accumulating across models built repeatedly in one process, callbacks or logs retaining data, input pipelines caching more than intended, generators holding references, and TensorFlow's allocator reserving memory in a way that looks like a leak from the outside.

First Distinguish A Leak From Reserved Memory

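TensorFlow often reserves memory aggressively, especially on GPU. By default it maps nearly all of the visible GPU memory to the process when a GPU is first used, and its allocator keeps memory it has freed in an internal pool for reuse instead of returning it to the driver. System monitors such as nvidia-smi can therefore show high usage even when the framework intends to reuse that memory.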
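If the large GPU number on its own is what looks alarming, one option is to ask TensorFlow to allocate GPU memory on demand instead of up front. The sketch below uses tf.config.experimental.set_memory_growth, which has to be called before any GPU has been initialized; on a CPU-only machine the loop simply does nothing. With growth enabled, reported usage tracks what the process actually needed, although freed memory is still kept in TensorFlow's pool rather than handed back.

```python
import tensorflow as tf

# Must run before any GPU is initialized (before building models or tensors on it).
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    # Allocate GPU memory as needed instead of mapping nearly all of it at startup.
    tf.config.experimental.set_memory_growth(gpu, True)

print(f"memory growth enabled on {len(gpus)} GPU(s)")
```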

A real leak usually has one of these patterns:

  • every epoch increases host memory and it never stabilizes
  • every repeated model creation increases memory because old graphs stay alive
  • callbacks or saved histories grow with training duration

A simple one-time jump that stays flat is usually allocation strategy, not a leak.
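To tell these patterns apart, it helps to record memory once per epoch and look at the shape of the curve. The sketch below is only a heuristic: peak_rss_mb and looks_like_leak are hypothetical helpers, the two-epoch warm-up and 5 MB tolerance are arbitrary assumptions, and the resource module is available on Unix-like systems only.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of the current process in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def looks_like_leak(samples_mb, warmup_epochs=2, tolerance_mb=5.0):
    """Hypothetical heuristic: after the warm-up epochs, does memory keep
    rising by more than `tolerance_mb` on every remaining epoch?"""
    tail = samples_mb[warmup_epochs:]
    if len(tail) < 2:
        return False  # too few epochs to judge
    deltas = [later - earlier for earlier, later in zip(tail, tail[1:])]
    return all(delta > tolerance_mb for delta in deltas)


# One-time jump, then flat: allocation strategy, not a leak.
print(looks_like_leak([500, 2100, 2105, 2104, 2106]))  # False
# Steady climb every epoch: worth investigating.
print(looks_like_leak([500, 650, 800, 950, 1100]))     # True
```

In practice you would append peak_rss_mb() to a list at the end of each epoch and pass that list to looks_like_leak. Peak RSS never decreases, which suits this check: a plateau means the process stopped needing more memory.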

Common Real Causes

The most frequent causes are:

  • creating new models in a loop without clearing old ones
  • calling fit repeatedly inside code that also accumulates tensors or histories
  • custom callbacks storing full predictions or batch outputs every epoch
  • tf.data caching or prefetch pipelines holding more data than expected
  • Python references preventing garbage collection

The right fix depends on which of these is happening.
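The last cause in that list is easy to confirm in plain Python. In the sketch below, Batch is a hypothetical stand-in for a large decoded batch, and a weakref.ref probe is used to observe whether the object is still alive without keeping it alive; nothing here is Keras-specific.

```python
import gc
import weakref


class Batch:
    """Hypothetical stand-in for a large decoded batch (about 10 MB)."""

    def __init__(self, num_bytes):
        self.data = bytearray(num_bytes)


saved_batches = []              # e.g. a list stored on a callback or module
batch = Batch(10_000_000)
probe = weakref.ref(batch)      # observes the object without keeping it alive

saved_batches.append(batch)
del batch                       # the local name is gone...
gc.collect()
alive_while_listed = probe() is not None   # ...but the list still holds it

saved_batches.clear()           # drop the last strong reference
gc.collect()
alive_after_clear = probe() is not None

print(alive_while_listed, alive_after_clear)  # True False
```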

A Safe Pattern For Repeated Model Creation

If you rebuild models many times, clear the old graph between runs.

```python
import tensorflow as tf

for _ in range(3):
    # Drop the global Keras state left over from the previous iteration's model.
    tf.keras.backend.clear_session()

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

    model.compile(optimizer="adam", loss="mse")
```

Use clear_session() between separate model lifecycles, not as a routine inside every epoch of one normal training run.
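When the repeated model creation is something like a hyperparameter sweep, one pattern that tends to keep memory flat is to keep only small Python results from each run and drop every reference to the model before the next iteration. The sketch below assumes toy random data and a hypothetical list of layer widths; the explicit del and gc.collect() are a common belt-and-braces habit rather than something clear_session() is documented to require.

```python
import gc

import numpy as np
import tensorflow as tf

# Toy regression data (assumption: stands in for your real training set).
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 10)).astype("float32")
y = rng.normal(size=(256, 1)).astype("float32")

results = {}
for units in (16, 32, 64):  # hypothetical widths to sweep over
    tf.keras.backend.clear_session()
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(units, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    history = model.fit(x, y, epochs=2, batch_size=32, verbose=0)

    # Keep a plain float, not the model or the History object.
    results[units] = float(history.history["loss"][-1])

    # Drop the references so the old model can be garbage-collected.
    del model, history
    gc.collect()

print(results)
```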

Example Of A Self-Inflicted Leak

This callback stores every prediction for every epoch in memory.

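```python
import tensorflow as tf

class BadCallback(tf.keras.callbacks.Callback):
    def __init__(self, x_val):
        super().__init__()
        # Current Keras versions do not populate self.validation_data,
        # so the validation inputs are passed in explicitly.
        self.x_val = x_val
        self.saved = []

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.x_val, verbose=0)
        # A full prediction array is appended every epoch and never released.
        self.saved.append(preds)
```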

The model may look like it leaks, but the real issue is the callback retaining large arrays indefinitely.

A safer callback stores summary statistics instead of full tensors.

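```python
import tensorflow as tf

class BetterCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.epoch_losses = []  # one Python float per epoch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch_losses.append(float(logs.get("loss", 0.0)))
```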
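If per-epoch scalars are all you need, a custom callback may not be necessary: the History object returned by fit() already records each logged metric as a list with one value per epoch. A minimal sketch, assuming toy random data in place of real inputs:

```python
import numpy as np
import tensorflow as tf

# Toy data (assumption): 128 samples with 10 features and one regression target.
rng = np.random.default_rng(0)
x = rng.normal(size=(128, 10)).astype("float32")
y = rng.normal(size=(128, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

history = model.fit(x, y, epochs=3, batch_size=32, verbose=0)

# One small scalar per epoch per metric; no prediction arrays are retained.
print({name: len(values) for name, values in history.history.items()})
```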

Watch The Input Pipeline

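Calling tf.data.Dataset.cache() without a filename deliberately keeps every element it produces in memory, so usage climbs throughout the first pass over the data. That is sometimes exactly what you want, but it can also be the hidden reason memory keeps growing until the dataset has been seen once.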
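A minimal sketch of the two cache placements, assuming a hypothetical decode step (here just a cast) and an example cache path under the system temp directory. With no argument, cache() holds the decoded elements in RAM; with a filename prefix, the cached elements are written to disk instead. The prefetch buffer, by contrast, holds only a bounded number of upcoming batches.

```python
import os
import tempfile

import tensorflow as tf


def decode(x):
    # Hypothetical stand-in for an expensive decode/preprocess step.
    return tf.cast(x, tf.float32) * 2.0


raw = tf.data.Dataset.range(1000)

# In-memory cache: every decoded element stays in RAM once it has been produced.
in_memory = raw.map(decode).cache().batch(100).prefetch(tf.data.AUTOTUNE)

# File-backed cache: decoded elements are written under this prefix instead.
cache_prefix = os.path.join(tempfile.mkdtemp(), "decoded_cache")
on_disk = raw.map(decode).cache(cache_prefix).batch(100)

for epoch in range(2):
    # The first epoch fills each cache; the second reads from it.
    n_mem = sum(int(b.shape[0]) for b in in_memory)
    n_disk = sum(int(b.shape[0]) for b in on_disk)

print(n_mem, n_disk)
```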

Likewise, custom Python generators may keep references to large batches or decoded data structures if they are written carelessly.
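As an illustration in plain Python, the careless generator below also appends every batch to a list it was handed, so no batch can ever be freed; the lean version holds only the batch currently being consumed. make_batch is a hypothetical stand-in for loading and decoding one batch.

```python
def make_batch(index, size=1_000_000):
    # Hypothetical stand-in for loading and decoding batch number `index`.
    return bytearray(size)


def careless_batches(num_batches, seen):
    """Leaky: every batch is also appended to `seen`, so none can be freed."""
    for i in range(num_batches):
        batch = make_batch(i)
        seen.append(batch)
        yield batch


def lean_batches(num_batches):
    """Holds only the current batch; earlier ones become garbage after use."""
    for i in range(num_batches):
        yield make_batch(i)


seen = []
for batch in careless_batches(5, seen):
    pass  # a training step would consume `batch` here
retained_bytes = sum(len(b) for b in seen)

for batch in lean_batches(5):
    pass

print(retained_bytes)  # 5000000
```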

Practical Debugging Steps

Use a short checklist:

  1. train for a few epochs and check whether memory stabilizes
  2. disable custom callbacks
  3. simplify the input pipeline (for example, remove cache() and custom generators)
  4. run once on CPU only to separate host-memory behavior from GPU behavior
  5. rerun the script in a fresh process and compare the memory curves across runs

That process tells you whether the issue is allocator behavior, pipeline retention, or repeated graph creation.
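Steps 1 and 4 of the checklist can be combined in one small script. The sketch below assumes Linux or macOS (the resource module is Unix-only), hides GPUs by setting CUDA_VISIBLE_DEVICES to -1 before TensorFlow is imported, trains on toy random data, and uses a hypothetical MemoryLogger callback that records peak process RSS at the end of each epoch.

```python
import os
import resource
import sys

# Step 4: hide GPUs so only host memory is measured (must be set before importing TF).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
import tensorflow as tf


class MemoryLogger(tf.keras.callbacks.Callback):
    """Hypothetical helper: records peak RSS (MB) after every epoch."""

    def __init__(self):
        super().__init__()
        self.peak_mb = []

    def on_epoch_end(self, epoch, logs=None):
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is kilobytes on Linux and bytes on macOS.
        self.peak_mb.append(peak / (1024 * 1024 if sys.platform == "darwin" else 1024))


rng = np.random.default_rng(0)
x = rng.normal(size=(2048, 10)).astype("float32")
y = rng.normal(size=(2048, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Step 1: a few epochs with no other custom callbacks, then inspect the curve.
logger = MemoryLogger()
model.fit(x, y, epochs=5, batch_size=64, verbose=0, callbacks=[logger])
print([round(m, 1) for m in logger.peak_mb])
```

A list that levels off after the first epoch or two points toward allocation behavior; one that climbs by a similar amount every epoch is worth chasing further.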

Common Pitfalls

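The most common mistake is calling clear_session() inside a normal epoch loop while still using the same model. The live model is still referenced, so clearing the session cannot free it, and resetting Keras's global state in the middle of a run can break training logic.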
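If you genuinely need your own epoch loop (for example, to run custom logic between epochs), one pattern that avoids this is to keep the same model and pass initial_epoch so each fit() call trains exactly one more epoch, with no clear_session() inside the loop. Toy random data is an assumption here.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 10)).astype("float32")
y = rng.normal(size=(256, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

epoch_losses = []  # plain floats only
for epoch in range(3):
    # Same model every iteration; no clear_session() here.
    history = model.fit(x, y, epochs=epoch + 1, initial_epoch=epoch,
                        batch_size=32, verbose=0)
    epoch_losses.append(float(history.history["loss"][-1]))
    # Custom per-epoch logic would go here.

print(epoch_losses)
```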

Another mistake is assuming Keras is leaking when TensorFlow is simply keeping GPU memory reserved for reuse.

A third issue is forgetting that Python containers such as lists, histories, and callback fields can keep large arrays alive long after an epoch ends.

Summary

  • Memory growth per epoch can come from retained Python objects, graph accumulation, dataset caching, or allocator behavior.
  • A one-time reserved-memory jump is not the same as a real leak.
  • Use clear_session() between separate model creations, not as an every-epoch ritual.
  • Inspect callbacks and input pipelines before blaming the framework.
  • Simplify the training loop to isolate where memory actually starts accumulating.
