How to Use tf.while_loop in TensorFlow

Introduction

tf.while_loop lets you express a loop whose stopping condition depends on tensor values instead of ordinary Python values. It is most useful when you need graph-compatible control flow, especially inside functions decorated with @tf.function, where you want explicit control over how the loop is staged into the graph.

Understand the Core Signature

The basic form is:

python
tf.while_loop(cond, body, loop_vars)

  • cond receives the current loop variables and returns a scalar boolean tensor
  • body receives the same loop variables and returns their updated values, in the same order and structure
  • loop_vars is the initial state, passed as a tuple or list of tensors

Here is a simple example that sums the integers from 0 through n - 1. With n = 5 that is 0 + 1 + 2 + 3 + 4 = 10.

python
import tensorflow as tf


def sum_first_n(n):
    # Keep looping while the counter is below n.
    def cond(i, total):
        return i < n

    # Advance the counter and add the current value to the running total.
    def body(i, total):
        return i + 1, total + i

    # loop_vars holds the initial (i, total) pair.
    _, result = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(tf.constant(0), tf.constant(0)),
    )
    return result


print(sum_first_n(tf.constant(5)).numpy())  # 10

The loop variables are tensors, and every iteration returns the next version of those tensors.

Use It When the Loop Depends on Tensor Values

A Python loop runs immediately in Python. tf.while_loop builds TensorFlow control flow that can be traced into a graph. That matters when:

  • the number of iterations depends on a tensor
  • you want the loop staged inside @tf.function
  • you need graph execution or export-friendly behavior
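
For example, this function collects powers of two until the value exceeds a tensor-valued limit, so the number of iterations is only known at run time:

python
import tensorflow as tf


@tf.function
def powers_of_two(limit):
    # A growable TensorArray collects one value per iteration.
    values = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)

    # Stop once the current power exceeds the limit.
    def cond(i, current, out):
        return current <= limit

    # Store the current power at index i, then double it.
    def body(i, current, out):
        out = out.write(i, current)
        return i + 1, current * 2, out

    _, _, out = tf.while_loop(
        cond,
        body,
        loop_vars=(tf.constant(0), tf.constant(1), values),
    )

    # Convert the collected values into a single 1-D tensor.
    return out.stack()


print(powers_of_two(tf.constant(20)).numpy())  # [ 1  2  4  8 16]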

TensorArray is commonly used because tensors themselves are immutable and cannot simply be appended to inside the loop.
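
To make the TensorArray mechanics concrete, here is a minimal sketch that uses it outside of any tf.while_loop, in eager mode. The helper name collect_squares is hypothetical and only serves the illustration. The point to notice is that write() returns the TensorArray to use from then on, so the result has to be reassigned each time.

```python
import tensorflow as tf


def collect_squares(n):
    """Collect 0*0, 1*1, ... for n values (hypothetical helper for illustration)."""
    out = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    for i in range(n):
        # write() returns the TensorArray to keep using; always reassign it.
        out = out.write(i, i * i)
    # stack() turns the written elements into one 1-D tensor.
    return out.stack()


squares = collect_squares(4)
print(squares.numpy())  # [0 1 4 9]
```

The same reassign-on-write pattern is what the loop body above does with out = out.write(i, current).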

Handle Shapes Explicitly When They Change

TensorFlow expects loop variable shapes to stay compatible across iterations. If a shape may grow or vary, use shape_invariants.

python
import tensorflow as tf


def build_prefix(limit):
    def cond(i, values):
        return i < limit

    # Append the current counter to the growing 1-D tensor.
    def body(i, values):
        values = tf.concat([values, [i]], axis=0)
        return i + 1, values

    _, values = tf.while_loop(
        cond,
        body,
        loop_vars=(tf.constant(0), tf.constant([], dtype=tf.int32)),
        shape_invariants=(
            tf.TensorShape([]),      # i stays a scalar
            tf.TensorShape([None]),  # values is 1-D with unknown length
        ),
    )
    return values


print(build_prefix(tf.constant(4)).numpy())  # [0 1 2 3]

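Without shape_invariants, TensorFlow may reject the loop when it is traced into a graph (for example inside @tf.function), because values enters the loop with shape [0] and comes out of the first iteration with shape [1]. In eager mode the loop simply runs step by step in Python, so the problem only shows up once the code is staged.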
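
The sketch below illustrates that failure mode under one assumption: the same loop is wrapped in @tf.function and the shape_invariants argument is left out. The names build_prefix_no_invariants and tracing_fails are hypothetical. In the TensorFlow 2 releases I am aware of, tracing such a loop raises a ValueError that mentions shape_invariants, but the exact message may differ between versions.

```python
import tensorflow as tf


@tf.function
def build_prefix_no_invariants(limit):
    def cond(i, values):
        return i < limit

    def body(i, values):
        # The length of values grows by one on every iteration.
        return i + 1, tf.concat([values, [i]], axis=0)

    # No shape_invariants: the loop is assumed to keep values at shape [0].
    _, values = tf.while_loop(
        cond,
        body,
        loop_vars=(tf.constant(0), tf.constant([], dtype=tf.int32)),
    )
    return values


def tracing_fails():
    """Return True if tracing the loop raises ValueError (assumed error type)."""
    try:
        build_prefix_no_invariants(tf.constant(4))
    except ValueError:
        return True
    return False


print(tracing_fails())
```

Passing shape_invariants with tf.TensorShape([None]) for values, as in the example above, is what tells TensorFlow that the changing length is intentional.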

Prefer Simpler Alternatives When Possible

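In TensorFlow 2, AutoGraph automatically converts a regular Python loop inside @tf.function into graph control flow when the loop condition or iterable is a tensor. Loops over plain Python values are instead unrolled while the function is traced. Either way, you do not need tf.while_loop for every iterative pattern.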
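
As a rough sketch of what that looks like, the summation from the first example can be written with a plain while loop and left to AutoGraph. The function name sum_first_n_autograph is hypothetical. Because the condition i < n is a tensor, the loop is staged as graph control flow rather than unrolled.

```python
import tensorflow as tf


@tf.function
def sum_first_n_autograph(n):
    # Loop state is initialized as tensors before the loop starts.
    i = tf.constant(0)
    total = tf.constant(0)
    # The condition depends on a tensor, so AutoGraph stages this loop.
    while i < n:
        total += i
        i += 1
    return total


result = sum_first_n_autograph(tf.constant(5))
print(result.numpy())  # 10
```

The loop state here is implicit in the variables i and total, whereas tf.while_loop makes the same state explicit through loop_vars.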

Use tf.while_loop when:

  • AutoGraph is not expressing the loop the way you need
  • you need direct control over loop variables and invariants
  • you are writing lower-level graph-oriented code

If an ordinary Python for loop over a known range works cleanly, it is often easier to read.

That readability point matters in production code. tf.while_loop is powerful, but it exposes low-level loop state explicitly, so it is best reserved for cases where that extra control is actually needed.

Common Pitfalls

  • Returning a different number of loop variables from body than you passed in through loop_vars.
  • Mutating Python lists inside the loop instead of using TensorArray or tensor-based state.
  • Forgetting shape_invariants when tensor shapes change across iterations.
  • Reaching for tf.while_loop when a normal Python loop inside @tf.function would be simpler.

Summary

  • tf.while_loop is for tensor-driven loops that need TensorFlow control flow.
  • cond, body, and loop_vars define the loop state and stopping condition.
  • Use TensorArray for values accumulated across iterations.
  • Add shape_invariants when loop variable shapes can grow or vary.
  • Prefer simpler Python loops unless you specifically need graph-level loop control.
