Conditional execution in TensorFlow

Introduction

TensorFlow 2 runs eagerly by default, so simple conditionals often look like normal Python if statements. The interesting part is what happens when the condition depends on a tensor and the code is traced into a graph with tf.function.

Python if Versus tf.cond

There are two common situations:

  1. eager execution with ordinary Python values
  2. graph-traced execution where the predicate is a tensor

In eager mode, this is just Python:

```python
import tensorflow as tf

def eager_branch(x):
    # Plain Python control flow: x is an ordinary Python value here.
    if x > 0:
        return x + 1
    return x - 1

print(eager_branch(3))
```

But if the predicate is a tensor inside a tf.function, TensorFlow needs graph control flow. In that case, AutoGraph usually converts the Python if into tf.cond.

Explicit tf.cond

tf.cond is the graph-level conditional operator. It chooses between two branch functions:

```python
import tensorflow as tf

@tf.function
def choose_branch(x):
    # tf.cond takes a boolean scalar tensor and two zero-argument branch functions.
    return tf.cond(
        x > 0,
        lambda: x + 1,
        lambda: x - 1,
    )

print(choose_branch(tf.constant(5)).numpy())
print(choose_branch(tf.constant(-2)).numpy())
```

Both branch functions must return compatible structures and dtypes.
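Branches can also return nested structures, as long as both sides match. A minimal sketch, assuming a hypothetical sign_and_magnitude helper: each branch returns a two-element tuple of int32 scalars, so tf.cond can merge the outputs.

```python
import tensorflow as tf

@tf.function
def sign_and_magnitude(x):
    # Both branches return a (sign, magnitude) tuple of int32 scalars,
    # so the output structures and dtypes are compatible.
    return tf.cond(
        x >= 0,
        lambda: (tf.constant(1), x),
        lambda: (tf.constant(-1), -x),
    )

sign, magnitude = sign_and_magnitude(tf.constant(-7))
print(sign.numpy(), magnitude.numpy())  # -1 7
```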

AutoGraph Usually Writes tf.cond for You

In modern TensorFlow, you often do not need to write tf.cond directly. When x is passed in as a tensor, the following version is equivalent:

```python
import tensorflow as tf

@tf.function
def choose_branch_autograph(x):
    # With a tensor predicate, AutoGraph converts this if/else into tf.cond.
    if x > 0:
        return x + 1
    else:
        return x - 1
```

AutoGraph rewrites the tensor-dependent if into TensorFlow control flow.
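One way to observe the rewrite is to count traces. The sketch below is illustrative: traces is a hypothetical Python list, and appending to it is a Python side effect that runs only while tf.function traces the function. Two tensor inputs with the same dtype and shape share one trace because the branch is chosen at run time, while plain Python ints are evaluated by Python during tracing, so each new value triggers a fresh trace.

```python
import tensorflow as tf

traces = []  # hypothetical Python-side record of tracing

@tf.function
def choose_branch_traced(x):
    traces.append("trace")  # Python side effect: runs at trace time only
    if x > 0:               # tensor predicate -> AutoGraph emits tf.cond
        return x + 1
    return x - 1

# Same dtype and shape: one traced graph, branch selected at run time.
print(choose_branch_traced(tf.constant(5)).numpy())   # 6
print(choose_branch_traced(tf.constant(-2)).numpy())  # -3
tensor_traces = len(traces)
print(tensor_traces)  # 1

# Python ints: the if is decided by Python while tracing, once per value.
choose_branch_traced(2)
choose_branch_traced(-2)
print(len(traces))  # 3
```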

This is one of the main reasons TensorFlow 2 code is easier to read than older graph-only TensorFlow 1 code.

tf.where Is Different

People often confuse tf.cond with tf.where, but they solve different problems.

  • tf.cond: choose between two branch computations
  • tf.where: choose element-wise values from tensors

Example of element-wise selection:

```python
import tensorflow as tf

x = tf.constant([1, -2, 3, -4], dtype=tf.int32)
# Keep positive entries, replace the rest with zeros.
result = tf.where(x > 0, x, tf.zeros_like(x))
print(result.numpy())
```

This does not mean "run one branch or the other once." It means "for each element, choose one value or another."
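Because both value tensors are ordinary inputs, tf.where computes them in full before selecting. A small sketch with illustrative data: the log of the negative entries is still computed (as NaN), even though those positions are never selected.

```python
import tensorflow as tf

x = tf.constant([4.0, -1.0, 1.0, -9.0])

# Both candidate tensors exist for every element before selection happens:
# log of a negative number is NaN, and it is computed anyway.
log_x = tf.math.log(x)
result = tf.where(x > 0, log_x, tf.zeros_like(x))

print(tf.math.is_nan(log_x).numpy())  # True at the negative positions
print(result.numpy())                 # NaNs are not selected into the result
```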

Be Careful with Work Created Outside the Branches

This is a subtle but important point from the TensorFlow docs: operations created outside the branch functions can still execute regardless of which branch is selected.

For example, if you compute an intermediate tensor before tf.cond, that computation is already part of the graph:

```python
import tensorflow as tf

@tf.function
def demo(x, y):
    z = x * y  # created outside the branches
    return tf.cond(
        x < y,
        lambda: x + z,
        lambda: y * y,
    )
```

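The multiply that creates z is outside the branches, so it runs on every call, even when the false branch (y * y) is selected and z is never used. Only the work created inside the branch functions is conditional.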
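To make the difference observable, the sketch below swaps the multiply for stateful counters, since a counter's effect can be read back afterwards. outside_runs, inside_runs, and demo_counted are hypothetical names. The counter updated outside the branches increments on every call, while the one updated inside true_fn increments only when that branch is taken; moving x * y into true_fn makes the multiply conditional in the same way.

```python
import tensorflow as tf

outside_runs = tf.Variable(0)  # hypothetical counters for illustration
inside_runs = tf.Variable(0)

@tf.function
def demo_counted(x, y):
    # Created outside the branch functions: executes on every call.
    outside_runs.assign_add(1)

    def true_fn():
        # Created inside true_fn: executes only when x < y.
        inside_runs.assign_add(1)
        return x + x * y  # the multiply now lives inside the branch

    return tf.cond(x < y, true_fn, lambda: y * y)

demo_counted(tf.constant(1), tf.constant(2))  # x < y: true branch
demo_counted(tf.constant(5), tf.constant(2))  # x >= y: false branch
print(outside_runs.numpy(), inside_runs.numpy())  # 2 1
```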

When Conditional Execution Is Useful

Conditional execution appears in:

  • custom training logic
  • model routing and mixture-of-experts style patterns
  • loss functions with piecewise definitions
  • branch-specific tensor transformations
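As one hedged example of custom training logic, a learning-rate warmup can be written with tf.cond, because a single scalar step selects one whole computation. WARMUP_STEPS, BASE_LR, and warmup_learning_rate are hypothetical names and values chosen for illustration, not a recommended schedule.

```python
import tensorflow as tf

# Hypothetical schedule parameters, chosen only for illustration.
WARMUP_STEPS = 100
BASE_LR = 1e-3

@tf.function
def warmup_learning_rate(step):
    step = tf.cast(step, tf.float32)
    # One scalar predicate picks one whole computation, so tf.cond fits.
    return tf.cond(
        step < WARMUP_STEPS,
        lambda: BASE_LR * (step + 1.0) / WARMUP_STEPS,  # linear warmup
        lambda: tf.constant(BASE_LR),                   # constant afterwards
    )

for s in [0, 49, 500]:
    print(s, float(warmup_learning_rate(tf.constant(s))))
```

If the choice had to be made per element rather than once per call, tf.where would be the better fit.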

For purely Python-side orchestration, normal if statements are enough. TensorFlow control flow matters when the condition should remain part of the traced computation.

Side Effects and State

Be careful with side effects inside graph-traced branch functions. TensorFlow is designed around tensor outputs, not arbitrary Python mutation. If you need a value from a branch, return it from the branch function instead of relying on Python state that was mutated inside the branch.

That rule avoids a lot of confusing bugs.
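The sketch below contrasts the two approaches; branch_log and both function names are hypothetical. Because tf.cond traces both branch functions while building the graph, a Python list mutated inside the branches records both branches after a single call, and it does not change on later calls that reuse the trace. Returning the information as part of the branch outputs gives the value that was actually selected at run time.

```python
import tensorflow as tf

branch_log = []  # hypothetical Python-side state, for illustration

@tf.function
def branch_with_python_state(x):
    def pos():
        branch_log.append("positive")  # runs while tracing, not per call
        return x + 1

    def neg():
        branch_log.append("negative")
        return x - 1

    return tf.cond(x > 0, pos, neg)

branch_with_python_state(tf.constant(3))
print(branch_log)  # ['positive', 'negative']: both branches were traced

@tf.function
def branch_returning_label(x):
    # Return the information from the branch instead of mutating Python state.
    return tf.cond(
        x > 0,
        lambda: (x + 1, tf.constant("positive")),
        lambda: (x - 1, tf.constant("negative")),
    )

value, label = branch_returning_label(tf.constant(-3))
print(value.numpy(), label.numpy())  # -4 b'negative'
```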

Common Pitfalls

The biggest mistake is branching with a Python if inside tf.function without considering how tracing treats the condition: a condition on a plain Python value is decided once, at trace time, while a condition on a tensor needs graph control flow. Inside tf.function, let AutoGraph convert tensor-dependent conditionals or use tf.cond explicitly.
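A minimal sketch of the failure mode, using a hypothetical python_if function: with AutoGraph enabled (the tf.function default) the tensor-dependent if is converted, while passing autograph=False leaves a plain Python if that tries to turn a symbolic tensor into a bool during tracing and raises an error (OperatorNotAllowedInGraphError, which subclasses TypeError).

```python
import tensorflow as tf

def python_if(x):
    if x > 0:  # Python `if` on a tensor
        return x + 1
    return x - 1

# With AutoGraph (the tf.function default), the `if` becomes tf.cond.
converted = tf.function(python_if)
print(converted(tf.constant(3)).numpy())  # 4

# With AutoGraph disabled, tracing must convert a symbolic tensor to a
# Python bool, which graph execution does not allow.
unconverted = tf.function(python_if, autograph=False)
raised = False
try:
    unconverted(tf.constant(3))
except TypeError:
    raised = True
print(raised)  # True
```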

Another mistake is using tf.where when you really want branch-level control flow. tf.where is for element-wise selection, not full branch execution.

People also assume everything outside the chosen tf.cond branch is skipped. Operations created outside the branch functions are not controlled by the branch selection.

Finally, make sure both branches return compatible tensor structures. Mismatched outputs will fail tracing.
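For illustration, the hypothetical mismatched function below returns int32 from one branch and float32 from the other, so tracing fails; matched casts once before tf.cond so both branches produce float32. The except clause catches both TypeError and ValueError as a hedge, since the exact exception type may differ across TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def mismatched(x):
    # True branch returns int32, false branch returns float32.
    return tf.cond(
        x > 0,
        lambda: x + 1,
        lambda: tf.cast(x, tf.float32) - 1.0,
    )

try:
    mismatched(tf.constant(1))
except (TypeError, ValueError) as err:
    print("tracing failed:", type(err).__name__)

@tf.function
def matched(x):
    # Casting once, before the branches, gives both branches the same dtype.
    xf = tf.cast(x, tf.float32)
    return tf.cond(x > 0, lambda: xf + 1.0, lambda: xf - 1.0)

print(matched(tf.constant(1)).numpy())  # 2.0
```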

Summary

  • In eager TensorFlow, ordinary Python if statements often work naturally.
  • Inside tf.function, tensor-dependent conditionals are typically converted to tf.cond.
  • Use tf.cond for branch-level graph control flow.
  • Use tf.where for element-wise selection between tensors.
  • Keep side effects out of branch logic and return needed values explicitly.
