Running TensorFlow on a Slurm Cluster?

Introduction

Running TensorFlow on Slurm is mostly about reliable environment setup and resource orchestration, not model architecture. A model that works locally can fail immediately on a cluster if CUDA libraries, NCCL setup, or network assumptions are wrong. The safest approach is staged validation from single process to distributed training, with reproducible job scripts and strong logging.

Build a Reproducible Runtime

Never rely on whatever Python happens to be on a login node. Create an explicit runtime and activate it inside every job.

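bash
# Module names are site-specific; list what your cluster offers with `module avail`.
module load python/3.10
# Create the virtual environment once, then activate it in every job script.
python -m venv ~/envs/tf310
source ~/envs/tf310/bin/activate
python -m pip install --upgrade pip
python -m pip install tensorflow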

For GPU jobs, verify that the TensorFlow build is compatible with the NVIDIA driver, CUDA runtime, and cuDNN version available on the compute nodes. A quick check script prevents expensive failures later.

python
import tensorflow as tf

print("TensorFlow:", tf.__version__)
print("Visible GPUs:", tf.config.list_physical_devices("GPU"))

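Run this check through sbatch on a compute node, not only on the login node: login nodes often lack GPUs or expose different driver and library versions.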

Start with a Correct Slurm Job Script

A clear sbatch script encodes assumptions in one place and makes runs repeatable.

bash
#!/bin/bash
#SBATCH --job-name=tf-single-gpu
#SBATCH --partition=gpu
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:1
#SBATCH --mem=32G
#SBATCH --time=02:00:00
#SBATCH --output=logs/%x-%j.out
#SBATCH --error=logs/%x-%j.err

# Create logs/ before submitting: Slurm does not create missing output directories.
# %x expands to the job name and %j to the job ID.

module load python/3.10
source ~/envs/tf310/bin/activate

srun python train.py --epochs 5 --batch-size 256

Submit and inspect:

bash
sbatch train.slurm
squeue -u "$USER"
sacct -j <jobid> --format=JobID,State,Elapsed,MaxRSS,AllocTRES

Keep log files per job ID so failures can be traced after nodes are released.

Validate Single-Node Throughput First

Before distributed training, confirm a stable single-node baseline. Example training script:

python
import numpy as np
import tensorflow as tf

x = np.random.rand(20000, 32).astype("float32")
y = (x.sum(axis=1) > 16).astype("float32")

dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(20000).batch(256).prefetch(tf.data.AUTOTUNE)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(dataset, epochs=3, verbose=2)

Capture samples per second and epoch time. This baseline helps detect regressions when you scale out.
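One way to capture these numbers is a small timing helper. The sketch below is framework-agnostic and uses only the standard library; the class name `ThroughputMeter` and its method names are illustrative, not part of TensorFlow.

```python
import time


class ThroughputMeter:
    """Track wall-clock time and samples per second for each epoch."""

    def __init__(self, samples_per_epoch, clock=time.perf_counter):
        self.samples_per_epoch = samples_per_epoch
        self.clock = clock      # injectable so the meter can be tested
        self.history = []       # one record per finished epoch
        self._start = None

    def epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started.
        self._start = self.clock()

    def epoch_end(self, epoch, logs=None):
        elapsed = self.clock() - self._start
        rate = self.samples_per_epoch / elapsed if elapsed > 0 else float("inf")
        record = {"epoch": epoch, "seconds": elapsed, "samples_per_sec": rate}
        self.history.append(record)
        print(f"epoch={epoch} seconds={elapsed:.2f} samples_per_sec={rate:.1f}")
        return record
```

With Keras, the two methods can be passed to `tf.keras.callbacks.LambdaCallback(on_epoch_begin=meter.epoch_begin, on_epoch_end=meter.epoch_end)` and the callback added to `model.fit`. For the script above, `samples_per_epoch` would be 20000.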

Multi-Node Strategy with TF_CONFIG

For distributed jobs, each task needs a deterministic role and host list. One common pattern is constructing TF_CONFIG from Slurm-provided hostnames.

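bash
# Expand the compact node list into one hostname per allocated node.
HOSTS=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
PORT=12345  # must be free on every node

# Build the worker list once for all nodes: "host1:port","host2:port",...
WORKERS=$(for h in $HOSTS; do printf '"%s:%s",' "$h" "$PORT"; done)
export WORKERS=${WORKERS%,}

# SLURM_PROCID differs per task only inside srun, so TF_CONFIG is built there.
# Building it in the batch script would give every worker the same index.
# Assumes one task per node (for example, --ntasks-per-node=1).
srun bash -c '
export TF_CONFIG="{\"cluster\": {\"worker\": [${WORKERS}]}, \"task\": {\"type\": \"worker\", \"index\": ${SLURM_PROCID}}}"
exec python distributed_train.py
'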
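The host-to-rank mapping can also be built inside the training script, which avoids shell quoting. The following is a minimal sketch, not the only way to do it: `build_tf_config`, `tf_config_from_env`, and the `WORKER_HOSTS` variable are hypothetical names, and `WORKER_HOSTS` is assumed to be a space-separated host list exported by the job script (for example from `scontrol show hostnames`).

```python
import json
import os


def build_tf_config(hostnames, task_index, port=12345):
    """Return a TF_CONFIG JSON string with one worker per hostname."""
    hosts = list(hostnames)
    if not 0 <= task_index < len(hosts):
        raise ValueError(f"task index {task_index} outside 0..{len(hosts) - 1}")
    config = {
        "cluster": {"worker": [f"{host}:{port}" for host in hosts]},
        "task": {"type": "worker", "index": task_index},
    }
    return json.dumps(config)


def tf_config_from_env(environ=os.environ):
    """Build TF_CONFIG for the current task.

    Assumes WORKER_HOSTS (hypothetical, exported by the job script) holds
    the space-separated host list and SLURM_PROCID is set by srun.
    """
    hosts = environ["WORKER_HOSTS"].split()
    return build_tf_config(hosts, int(environ["SLURM_PROCID"]))
```

In the training script, `os.environ["TF_CONFIG"] = tf_config_from_env()` would need to run before the distribution strategy is created, because the strategy reads `TF_CONFIG` at construction time. The host order must be identical on every task so that each index maps to the same address everywhere.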

In Python, select an appropriate strategy:

python
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")

Start with two workers and short runs before full-scale jobs.

Data and Checkpoint Management

Use shared storage for checkpoints and final artifacts, but prefer local scratch for temporary shards when available. Checkpointing should be periodic and resumable.

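python
import tensorflow as tf

# Save the full model at the end of every epoch; {epoch:02d} is filled in by Keras.
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/shared/checkpoints/run-2026-03-04/ckpt-{epoch:02d}.keras",
    save_best_only=False,
    save_weights_only=False,
)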

On restart, detect the latest checkpoint and continue training from it rather than starting from zero. This is essential on preemptible partitions, where a job can be cancelled and requeued at any time.
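Checkpoint detection can be sketched with the standard library alone. The helper below assumes the `ckpt-{epoch:02d}.keras` naming used above; `find_latest_checkpoint` is an illustrative name, not a TensorFlow function.

```python
import re
from pathlib import Path

# Matches the file names produced by the checkpoint pattern above.
CKPT_PATTERN = re.compile(r"ckpt-(\d+)\.keras$")


def find_latest_checkpoint(ckpt_dir):
    """Return (path, epoch) for the highest-numbered checkpoint, or None."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    candidates = []
    for path in directory.iterdir():
        match = CKPT_PATTERN.search(path.name)
        if match:
            # Compare epochs numerically so ckpt-10 sorts after ckpt-02.
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    epoch, path = max(candidates, key=lambda item: item[0])
    return path, epoch
```

If a checkpoint is found, the model can be restored with `tf.keras.models.load_model(path)` and training resumed with `model.fit(..., initial_epoch=epoch)`. Keras numbers `{epoch}` from 1 in checkpoint file names, so the parsed value lines up with the zero-based `initial_epoch` of the next epoch to run.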

Operational Observability

Include basic run metadata in logs at startup:

  • TensorFlow version,
  • detected devices,
  • global batch size,
  • effective learning rate,
  • checkpoint and dataset paths.

These fields cut debugging time dramatically when a run diverges or crashes after hours.
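A minimal sketch of such a startup record follows. The function name `collect_run_metadata` and its parameters are illustrative, and the effective learning rate assumes linear scaling with the replica count, which is a common convention rather than a rule.

```python
import json
import os
import sys


def collect_run_metadata(per_replica_batch, num_replicas, base_lr,
                         checkpoint_dir, dataset_path,
                         extra=None, environ=os.environ):
    """Gather startup metadata into one JSON-serializable dict."""
    meta = {
        "python": sys.version.split()[0],
        "slurm_job_id": environ.get("SLURM_JOB_ID"),
        "slurm_procid": environ.get("SLURM_PROCID"),
        "slurm_nodelist": environ.get("SLURM_JOB_NODELIST"),
        "num_replicas": num_replicas,
        "global_batch_size": per_replica_batch * num_replicas,
        # Assumption: learning rate is scaled linearly with replica count.
        "effective_lr": base_lr * num_replicas,
        "checkpoint_dir": checkpoint_dir,
        "dataset_path": dataset_path,
    }
    # Framework-specific fields (TensorFlow version, devices) are passed in.
    meta.update(extra or {})
    return meta


if __name__ == "__main__":
    record = collect_run_metadata(256, 1, 1e-3, "/shared/checkpoints", "/shared/data")
    print(json.dumps(record, indent=2))
```

In a training script, `extra` could carry `{"tensorflow": tf.__version__, "devices": [d.name for d in tf.config.list_physical_devices()]}`, and `num_replicas` could come from `strategy.num_replicas_in_sync`. Printing the record as a single JSON object makes it easy to search for in the per-job log files.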

Common Pitfalls

  • Installing dependencies on login nodes and assuming compute nodes expose identical libraries.
  • Requesting GPUs but too few CPU cores, starving data loading and reducing utilization.
  • Skipping single-node validation and debugging distributed issues without a baseline.
  • Writing frequent checkpoints to slow network storage and causing training stalls.
  • Building dynamic multi-node configs without deterministic host and rank mapping.
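The CPU-starvation pitfall can be partly avoided by sizing input-pipeline parallelism from the allocation instead of hard-coding it. This is a sketch under assumptions: `data_loader_threads` is a hypothetical helper, and reserving one core for the main training loop is a design choice, not a requirement.

```python
import os


def data_loader_threads(environ=os.environ, reserve=1, default=1):
    """Pick an input-pipeline parallelism level from the Slurm allocation.

    SLURM_CPUS_PER_TASK is only set when the job requests --cpus-per-task,
    so fall back to `default` when it is missing.
    """
    raw = environ.get("SLURM_CPUS_PER_TASK")
    if raw is None:
        return default
    # Keep `reserve` cores free for the training loop, but never go below 1.
    return max(1, int(raw) - reserve)
```

The result can be passed as `num_parallel_calls` to `dataset.map(...)`. With the `--cpus-per-task=8` request from the job script above, the helper returns 7.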

Summary

  • Treat TensorFlow on Slurm as an environment and orchestration discipline first.
  • Use explicit virtual environments or containers in every job script.
  • Validate GPU visibility and baseline throughput before scaling.
  • Configure distributed roles deterministically with Slurm metadata.
  • Log key run metadata and checkpoint frequently so failures are recoverable.
