multi-GPU training
Keras
deep learning
parallel computing
machine learning

How to do multi GPU training with Keras?

Master System Design with Codemia

Enhance your system design skills with over 120 practice problems, detailed solutions, and hands-on exercises.

Introduction

Multi-GPU training in Keras is typically done with TensorFlow distribution strategies, especially tf.distribute.MirroredStrategy for one machine with multiple GPUs. The strategy replicates the model on each GPU and synchronizes gradients each step.

A good setup includes dataset sharding, global batch-size planning, and performance validation. This article shows the standard pattern and the checks you should run before long training jobs.

Core Sections

1. Detect GPUs and define strategy

python
1import tensorflow as tf
2
3print(tf.config.list_physical_devices('GPU'))
4strategy = tf.distribute.MirroredStrategy()
5print('replicas:', strategy.num_replicas_in_sync)

The number of replicas should match visible GPUs.

2. Build model inside strategy scope

python
1with strategy.scope():
2    model = tf.keras.Sequential([
3        tf.keras.layers.Input(shape=(224, 224, 3)),
4        tf.keras.layers.Conv2D(32, 3, activation='relu'),
5        tf.keras.layers.GlobalAveragePooling2D(),
6        tf.keras.layers.Dense(10, activation='softmax')
7    ])
8    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Creating and compiling inside scope is required for proper variable placement.

3. Scale batch size and input pipeline

python
global_batch_size = 128
train_ds = train_ds.shuffle(10_000).batch(global_batch_size).prefetch(tf.data.AUTOTUNE)
model.fit(train_ds, epochs=10)

Tune global batch size to balance throughput and convergence stability.

4. Monitor performance and determinism

Track per-step time, GPU utilization, and validation metrics. Multi-GPU speedups are not guaranteed if your input pipeline is slow.

Use profiling tools (tf.profiler) to detect input bottlenecks before increasing model complexity.

5. Build a repeatable validation checklist

Once the implementation is in place, create a deterministic validation checklist for multi-GPU Keras training. At minimum, include one baseline scenario, one edge-case scenario, and one failure-path scenario with expected outcomes documented in plain language. This prevents knowledge from staying implicit and reduces the risk of regressions during dependency updates or refactors.

A useful checklist also captures runtime assumptions: framework versions, SDK versions, configuration flags, and environment variables required for a successful run. Many teams skip this because the setup seems obvious during initial development, but those hidden assumptions are usually what break first when code moves to CI, staging, or another developer machine.

text
1validation checklist
2- baseline case with expected output and key fields
3- edge case with constrained or unusual input
4- failure case with expected error handling behavior
5- recorded runtime and dependency assumptions

Keep this checklist versioned with code. If behavior changes, update the expected outputs in the same pull request so future debugging has an authoritative reference for what changed and why.

6. Operational hardening and maintenance

Long-term reliability for multi-GPU Keras training requires observability and explicit ownership. Add targeted logs and metrics around critical steps so incident responders can quickly identify whether failures come from input quality, environment drift, external service dependencies, or code regressions. Without these signals, most incident time is lost reconstructing context instead of fixing root causes.

Define maintenance routines for upgrades and compatibility checks. Libraries and platforms evolve continuously, and subtle behavior changes are common. Lightweight smoke tests should run regularly, not only during feature work, to catch drift before it reaches production.

bash
# example recurring check command
make smoke-test

Finally, document rollback criteria in advance. If a deployment changes multi-GPU Keras training behavior unexpectedly, teams should know when to roll back immediately versus when to hot-fix forward. This converts operational response from guesswork into a controlled process and improves overall system resilience.

Common Pitfalls

  • Building model outside distribution strategy scope.
  • Using per-replica batch size as if it were global batch size.
  • Ignoring input pipeline throughput and blaming GPUs for slow training.
  • Assuming linear speedup without measuring communication overhead.
  • Mixing incompatible CUDA/cuDNN/TensorFlow versions across nodes.

Summary

Keras multi-GPU training is straightforward with MirroredStrategy when model creation, batching, and data input are configured correctly. Treat performance as a system problem, not just a model problem. With proper profiling and batch-size tuning, you can achieve stable and meaningful training acceleration.


Course illustration
Course illustration

All Rights Reserved.