Compute gradients for each time step of tf.while_loop

Introduction

TensorFlow can backpropagate through tf.while_loop, but getting gradients for each time step requires one extra idea: you must keep the per-step quantities you care about instead of only the final result. Once those values are collected, GradientTape.jacobian is usually the cleanest way to obtain one gradient contribution per time step.

The difference between total gradient and per-step gradients

If you compute a total loss by summing all time-step losses, tape.gradient returns the gradient of that total. That is often what training needs, but it does not tell you how much each time step contributed.

For example:

  • tape.gradient(total_loss, w) gives one aggregated gradient
  • tape.jacobian(step_losses, w) gives one gradient entry per time step

That distinction is the key to understanding the problem.
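The difference can be seen without any loop at all. The following is a minimal sketch, assuming eager execution in TensorFlow 2.x; the names scales and summed_grad are illustrative and not part of the original examples.

```python
import tensorflow as tf

w = tf.Variable(2.0)
scales = tf.constant([1.0, 2.0, 3.0])  # illustrative stand-in for three time steps

# persistent=True because the tape is queried more than once below.
with tf.GradientTape(persistent=True) as tape:
    step_losses = scales * w * w             # shape (3,), one loss per "step"
    total_loss = tf.reduce_sum(step_losses)

total_grad = tape.gradient(total_loss, w)      # scalar: gradient of the summed loss
summed_grad = tape.gradient(step_losses, w)    # also a scalar: a vector target is summed
per_step_grad = tape.jacobian(step_losses, w)  # shape (3,): one entry per step
del tape  # release the resources held by the persistent tape

print(total_grad.numpy(), summed_grad.numpy(), per_step_grad.numpy())
```

Passing the vector step_losses to tape.gradient does not raise an error; it quietly returns the same aggregated value as differentiating total_loss.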

Store per-step values with TensorArray

Inside tf.while_loop, appending to a Python list is not reliable: in graph mode the loop body is traced once, so the list does not receive one value per iteration. The graph-friendly container is tf.TensorArray.

The following example runs a simple recurrence and stores a scalar loss for each step:

python
import tensorflow as tf

w = tf.Variable(0.8, dtype=tf.float32)
x0 = tf.constant(1.5, dtype=tf.float32)
steps = 5

# persistent=True allows both tape.gradient and tape.jacobian to be called below.
with tf.GradientTape(persistent=True) as tape:
    def cond(i, x, losses_ta):
        return i < steps

    def body(i, x, losses_ta):
        x_next = x * w + 1.0
        step_loss = x_next * x_next
        # write() returns the updated TensorArray; keep the returned value.
        losses_ta = losses_ta.write(i, step_loss)
        return i + 1, x_next, losses_ta

    init_ta = tf.TensorArray(dtype=tf.float32, size=steps)
    _, final_x, losses_ta = tf.while_loop(
        cond,
        body,
        loop_vars=(0, x0, init_ta)
    )

    # Stack the per-step scalars into a vector of shape (steps,).
    step_losses = losses_ta.stack()
    total_loss = tf.reduce_sum(step_losses)

total_grad = tape.gradient(total_loss, w)
per_step_grad = tape.jacobian(step_losses, w)

print("step losses:", step_losses.numpy())
print("total grad:", float(total_grad.numpy()))
print("per-step grad:", per_step_grad.numpy())

Because w is scalar here, the Jacobian is a vector with one entry per time step.

Why jacobian works here

step_losses is a vector of shape (T,), while w is scalar. The Jacobian therefore has one derivative for each element:

d step_losses[t] / d w

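Each entry is the full derivative of that step's loss, including the influence of w carried through all earlier steps of the recurrence. If your parameter is a tensor rather than a scalar, the Jacobian has shape step_losses.shape + w.shape. For example, if w is a 2x2 matrix and there are T steps, the result has shape (T, 2, 2): one matrix-shaped gradient per time step.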

That can be powerful, but it can also be memory-heavy.
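When the full Jacobian is too large to hold at once, one option is to differentiate one step loss at a time and keep only a summary of each gradient. The sketch below assumes eager execution and uses a persistent tape; recording only the gradient norm in grad_norms is an illustrative choice, not something the approach requires.

```python
import tensorflow as tf

W = tf.Variable([[0.3, 0.1], [0.2, 0.4]], dtype=tf.float32)
x0 = tf.constant([1.0, -1.0], dtype=tf.float32)
steps = 4

def cond(i, x, ta):
    return i < steps

def body(i, x, ta):
    x_next = tf.linalg.matvec(W, x)
    ta = ta.write(i, tf.reduce_sum(tf.square(x_next)))
    return i + 1, x_next, ta

# persistent=True because tape.gradient is called once per step below.
with tf.GradientTape(persistent=True) as tape:
    init_ta = tf.TensorArray(tf.float32, size=steps)
    _, _, ta = tf.while_loop(cond, body, (tf.constant(0), x0, init_ta))
    # Unstack inside the tape so every scalar loss is a recorded tensor.
    losses = tf.unstack(ta.stack())

grad_norms = []
for loss_t in losses:
    g = tape.gradient(loss_t, W)          # shape (2, 2); only one is alive at a time
    grad_norms.append(float(tf.norm(g)))  # keep a summary instead of the full gradient
del tape  # release the resources held by the persistent tape

print(grad_norms)
```

This trades speed for memory: it runs one backward pass per step instead of a single vectorized Jacobian computation.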

A parameter matrix example

The same idea works for a matrix parameter used repeatedly in the loop:

python
import tensorflow as tf

W = tf.Variable([[0.3, 0.1], [0.2, 0.4]], dtype=tf.float32)
x0 = tf.constant([1.0, -1.0], dtype=tf.float32)
steps = 4

with tf.GradientTape() as tape:
    def cond(i, x, ta):
        return i < steps

    def body(i, x, ta):
        # The same matrix W is reused at every step of the recurrence.
        x_next = tf.linalg.matvec(W, x)
        step_loss = tf.reduce_sum(tf.square(x_next))
        ta = ta.write(i, step_loss)
        return i + 1, x_next, ta

    init_ta = tf.TensorArray(tf.float32, size=steps)
    _, _, ta = tf.while_loop(cond, body, (0, x0, init_ta))
    step_losses = ta.stack()

# One (2, 2) gradient per step, stacked along the first axis.
per_step_grads = tape.jacobian(step_losses, W)
print(per_step_grads.shape)

Here the shape is (steps, 2, 2), meaning one matrix-shaped gradient for each step.

When to use gradient instead

If you only need the final training update, tape.gradient is simpler and cheaper:

python
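with tf.GradientTape() as tape:
    # Build step_losses with tf.while_loop and a TensorArray here,
    # exactly as in the examples above.
    total_loss = tf.reduce_sum(step_losses)

# One aggregated gradient with the same shape as W.
grad = tape.gradient(total_loss, W)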

Use jacobian only when you truly need time-step-level analysis, debugging, attribution, or custom weighting by step.
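As an illustration of custom weighting by step, the per-step gradients can be recombined with any weights after the fact. The sketch below assumes eager execution; step_weights holds hypothetical values chosen only for the demonstration. It also checks that reweighting the Jacobian agrees with differentiating a weighted loss directly.

```python
import tensorflow as tf

w = tf.Variable(0.8, dtype=tf.float32)
x0 = tf.constant(1.5, dtype=tf.float32)
steps = 5
step_weights = tf.constant([0.1, 0.2, 0.3, 0.2, 0.2])  # hypothetical per-step weights

def cond(i, x, ta):
    return i < steps

def body(i, x, ta):
    x_next = x * w + 1.0
    return i + 1, x_next, ta.write(i, x_next * x_next)

with tf.GradientTape(persistent=True) as tape:
    init_ta = tf.TensorArray(tf.float32, size=steps)
    _, _, ta = tf.while_loop(cond, body, (tf.constant(0), x0, init_ta))
    step_losses = ta.stack()
    weighted_loss = tf.reduce_sum(step_weights * step_losses)

per_step_grad = tape.jacobian(step_losses, w)             # shape (steps,)
reweighted = tf.reduce_sum(step_weights * per_step_grad)  # weights applied after differentiating
direct = tape.gradient(weighted_loss, w)                  # weights applied before differentiating
del tape

print(float(reweighted), float(direct))
```

If the weights are known in advance, the direct form is cheaper. The Jacobian form is mainly useful when you want to try several weightings without rerunning the loop.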

Practical debugging tips

If gradients come back as None, check these first:

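  1. the source must be watched by the tape: trainable tf.Variable objects are watched automatically, while plain tensors need tape.watch
  2. the variable must actually influence the stored losses
  3. the values written into TensorArray must be tensors produced by TensorFlow ops inside the taped region
  4. you should not cut the path with tf.stop_gradient or convert to NumPy inside the taped region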
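The checks above can be reproduced in isolation. The following is a small sketch, assuming eager execution; the variable named unused and the tensor named cut exist only to trigger the None cases.

```python
import tensorflow as tf

w = tf.Variable(2.0)
unused = tf.Variable(5.0)  # never used in the loss
x = tf.constant(3.0)       # plain tensor: not watched automatically

with tf.GradientTape(persistent=True) as tape:
    loss = w * x

grad_w = tape.gradient(loss, w)            # connected and watched
grad_unused = tape.gradient(loss, unused)  # None: no path from unused to loss
grad_x = tape.gradient(loss, x)            # None: x was never watched
# Ask for zeros instead of None when a source is unconnected.
grad_unused_zero = tape.gradient(
    loss, unused, unconnected_gradients=tf.UnconnectedGradients.ZERO)
del tape

with tf.GradientTape(persistent=True) as tape2:
    tape2.watch(x)                 # explicitly watch the constant
    loss2 = w * x
    cut = tf.stop_gradient(w) * x  # the path from w is blocked here

grad_x_watched = tape2.gradient(loss2, x)  # now defined
grad_w_cut = tape2.gradient(cut, w)        # None: stop_gradient removed the only path
del tape2

print(grad_w, grad_unused, grad_x, grad_unused_zero, grad_x_watched, grad_w_cut)
```

Running the same kind of probe on a single loop iteration is often the quickest way to find which of the four conditions is violated.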

A simple sanity check is to compare:

python
tf.reduce_sum(per_step_grad)

against:

python
tape.gradient(total_loss, w)

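For scalar w, they should match numerically up to floating-point error. For a tensor parameter, sum the Jacobian over its first (time) axis before comparing. Calling both jacobian and gradient on the same tape requires persistent=True.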
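For a matrix parameter, the same check can be written out end to end. This is a sketch under the same assumptions as the matrix example (eager execution, a 2x2 W, four steps); max_diff is an illustrative name.

```python
import tensorflow as tf

W = tf.Variable([[0.3, 0.1], [0.2, 0.4]], dtype=tf.float32)
x0 = tf.constant([1.0, -1.0], dtype=tf.float32)
steps = 4

def cond(i, x, ta):
    return i < steps

def body(i, x, ta):
    x_next = tf.linalg.matvec(W, x)
    return i + 1, x_next, ta.write(i, tf.reduce_sum(tf.square(x_next)))

# persistent=True because both jacobian and gradient are requested.
with tf.GradientTape(persistent=True) as tape:
    init_ta = tf.TensorArray(tf.float32, size=steps)
    _, _, ta = tf.while_loop(cond, body, (tf.constant(0), x0, init_ta))
    step_losses = ta.stack()
    total_loss = tf.reduce_sum(step_losses)

per_step_grads = tape.jacobian(step_losses, W)      # shape (steps, 2, 2)
summed = tf.reduce_sum(per_step_grads, axis=0)      # collapse the time axis -> (2, 2)
total_grad = tape.gradient(total_loss, W)           # shape (2, 2)
del tape

max_diff = float(tf.reduce_max(tf.abs(summed - total_grad)))
tf.debugging.assert_near(summed, total_grad, atol=1e-5)  # raises if they disagree
print("max abs difference:", max_diff)
```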

Common Pitfalls

The most common mistake is passing a vector of step losses to tape.gradient and expecting one gradient per step. For a non-scalar target, gradient returns the gradient of the sum of its elements, so you get the aggregated result unless you call jacobian. Another frequent issue is appending to Python lists inside tf.while_loop instead of writing to a TensorArray, which does not work once the loop runs in graph mode. Developers also sometimes read only the final loop state and forget to store the intermediate loss terms they later want to differentiate. Memory usage is another trap: the Jacobian holds one parameter-shaped gradient per step, so it grows with both the sequence length and the parameter size. Finally, leaving TensorFlow inside the taped region, for example by converting a value to NumPy and back, breaks the recorded path and leads to gradients being None.

Summary

  • tf.while_loop supports backpropagation, but per-step gradients require storing per-step values.
  • Use TensorArray to collect losses or activations inside the loop.
  • Use GradientTape.jacobian when you want one gradient contribution per time step.
  • Use GradientTape.gradient when you only need the total update for training.
  • Verify that stored tensors depend on the watched variables and stay inside the graph.
  • Expect memory costs to rise when you request Jacobians for long sequences or large parameter tensors.
