
Confused by the behavior of tf.cond

Introduction

tf.cond is TensorFlow’s graph-level conditional operator. Many developers get confused by it because its behavior depends on whether code runs eagerly or inside a traced graph such as a tf.function. The short version is that tf.cond executes only one branch, but inside a traced graph both branch functions are still traced, so both take part in graph construction and output checking.

What tf.cond Actually Does

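At a high level, tf.cond(pred, true_fn, false_fn) behaves like an if-else whose predicate is a scalar boolean tensor rather than a normal Python boolean. true_fn and false_fn are zero-argument callables, typically lambdas, that produce the value for each case.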

```python
import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)

result = tf.cond(
    x > y,
    lambda: x + 10,
    lambda: y + 10,
)

print(result.numpy())
```

Here x > y evaluates to True, so tf.cond returns the result of true_fn and the snippet prints 15. If the predicate were False, it would return the result of false_fn instead.
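The branches are passed as callables rather than as precomputed tensors, which is what lets tf.cond decide which one to call. A minimal sketch, run in eager mode; the helper name larger_plus_ten is hypothetical and only wraps the call above so both outcomes can be compared:

```python
import tensorflow as tf


def larger_plus_ten(x, y):
    # Hypothetical helper: both branches are zero-argument lambdas closing over x and y
    return tf.cond(x > y, lambda: x + 10, lambda: y + 10)


print(larger_plus_ten(tf.constant(5), tf.constant(3)).numpy())  # 15 (true_fn)
print(larger_plus_ten(tf.constant(2), tf.constant(7)).numpy())  # 17 (false_fn)

# Passing tensors instead of callables is rejected before any branch runs
try:
    tf.cond(tf.constant(True), tf.constant(1), tf.constant(0))
    rejected = False
except TypeError as err:
    rejected = True
    print("TypeError:", err)
```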

Why It Feels Different from Python if

A Python if needs a concrete Python boolean and is decided immediately by the interpreter. tf.cond is designed for control flow whose predicate is a tensor evaluated by TensorFlow. That matters most inside @tf.function, where TensorFlow traces your Python function into a graph representation of the computation.

```python
import tensorflow as tf

@tf.function
def choose_value(x):
    return tf.cond(
        x > 0,
        lambda: x * 2,
        lambda: -x,
    )

print(choose_value(tf.constant(4)))
print(choose_value(tf.constant(-3)))
```

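In this form, x is a symbolic tensor while the function is being traced, so Python cannot know its value, and therefore cannot pick the branch, at trace time. TensorFlow must represent both possible branches in the graph and select one when the graph runs.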
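One way to see that both branches end up in the graph is to trace the function once and list the operations it contains. A minimal sketch; it inspects the traced graph through get_concrete_function, and the op names it looks for (StatelessIf for side-effect-free branches, If otherwise) are implementation details that may differ between TensorFlow versions:

```python
import tensorflow as tf


@tf.function
def choose_value(x):
    return tf.cond(x > 0, lambda: x * 2, lambda: -x)


# Trace once for scalar int32 inputs
concrete = choose_value.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.int32))
op_types = [op.type for op in concrete.graph.get_operations()]

# A single conditional op holds both branches as nested function graphs
cond_ops = [t for t in op_types if t in ("If", "StatelessIf")]
print(cond_ops)

# The same traced graph serves both outcomes
print(concrete(tf.constant(4)).numpy())   # 8
print(concrete(tf.constant(-3)).numpy())  # 3
```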

Both Branches Are Traced, Not Both Executed

This is the source of most confusion. Developers observe branch code being built or validated and assume both branches executed. What actually happens is:

  1. TensorFlow traces both branch functions to build a valid graph.
  2. At runtime, only the branch selected by the predicate executes, and its result is returned.

That means branch functions should be pure and TensorFlow-friendly. If you place Python side effects inside them, the results can be surprising.

```python
import tensorflow as tf

@tf.function
def demo(x):
    return tf.cond(
        x > 0,
        lambda: tf.constant(1),
        lambda: tf.constant(0),
    )

print(demo(tf.constant(2)))
```

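This is predictable because the branches only produce tensors. It gets murkier with ordinary Python side effects such as print calls or list mutation: they run while the branches are traced, so they fire for both branches once per trace, not once per call for the branch that is actually taken.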
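The gap between tracing and execution shows up as soon as a branch has side effects. In this sketch (trace_log, true_runs, false_runs and demo_with_effects are hypothetical names), the Python list.append calls run while both branches are traced, whereas the tf.Variable updates are TensorFlow ops and run only in the branch actually taken:

```python
import tensorflow as tf

trace_log = []               # Python side effects land here during tracing
true_runs = tf.Variable(0)   # TensorFlow state, updated at execution time
false_runs = tf.Variable(0)


@tf.function
def demo_with_effects(x):
    def true_fn():
        trace_log.append("true_fn traced")  # runs once per trace
        true_runs.assign_add(1)             # runs each time this branch executes
        return x * 2

    def false_fn():
        trace_log.append("false_fn traced")
        false_runs.assign_add(1)
        return -x

    return tf.cond(x > 0, true_fn, false_fn)


print(demo_with_effects(tf.constant(4)).numpy())   # 8
print(demo_with_effects(tf.constant(-3)).numpy())  # 3, reuses the first trace
print(trace_log)                                    # each branch traced once
print(true_runs.numpy(), false_runs.numpy())        # 1 1
```

Python-level logging therefore says more about tracing than about execution; tf.print or TensorFlow state is the way to observe which branch actually ran.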

Eager Mode Versus tf.function

In TensorFlow 2, eager execution is enabled by default. In plain eager code, you often do not need tf.cond at all.

```python
import tensorflow as tf

def eager_example(x):
    if x.numpy() > 0:
        return x * 2
    return -x

print(eager_example(tf.constant(5)))
```

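That works in eager mode because x.numpy() gives Python a concrete value. It does not carry over to traced execution: under tf.function, x is a symbolic tensor with no concrete value, so x.numpy() raises an error during tracing. If the function needs to run under tf.function, the condition should stay a tensor, either through an explicit tf.cond or through a Python if on the tensor itself, which AutoGraph converts to tf.cond when the function is traced.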
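A sketch of both points, assuming TensorFlow 2 with AutoGraph enabled (the tf.function default). Wrapping eager_example in tf.function fails at trace time because x has no concrete value, while a version that tests the tensor directly is rewritten by AutoGraph into a tf.cond. The name autograph_example is hypothetical, and the exact exception type raised by x.numpy() may vary by version, so the sketch catches Exception:

```python
import tensorflow as tf


def eager_example(x):
    if x.numpy() > 0:
        return x * 2
    return -x


try:
    # During tracing x is symbolic, so x.numpy() has no value to return
    tf.function(eager_example)(tf.constant(5))
    traced_ok = True
except Exception as err:
    traced_ok = False
    print(type(err).__name__, err)


@tf.function
def autograph_example(x):
    # Tensor-valued condition: AutoGraph converts this if/else into tf.cond
    if x > 0:
        return x * 2
    else:
        return -x


print(autograph_example(tf.constant(5)).numpy())   # 10
print(autograph_example(tf.constant(-3)).numpy())  # 3
```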

Branch Output Rules

Both branches must return matching output structures. TensorFlow expects the branch outputs to line up in count and nesting, with the same dtype in each position.

Bad example:

```python
# Mismatch: true_fn returns one tensor, false_fn returns a pair of tensors
tf.cond(x > 0, lambda: x, lambda: (x, x))
```

Correct example:

```python
import tensorflow as tf

result = tf.cond(
    tf.constant(True),
    lambda: (tf.constant(1), tf.constant(2)),
    lambda: (tf.constant(3), tf.constant(4)),
)

print(result)
```

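If the branch outputs do not match, TensorFlow raises an error while tracing the graph (for example under tf.function) rather than letting the mismatch slip into runtime. In eager mode, however, tf.cond calls only the selected branch, so a mismatch can go unnoticed until the same code is traced.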
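A minimal sketch of when that check happens, using the same kind of mismatch as the bad example (one tensor versus a pair). The function name mismatched is hypothetical, and because the exact exception type and message depend on the TensorFlow version, the sketch catches Exception:

```python
import tensorflow as tf


@tf.function
def mismatched(x):
    # true_fn returns one tensor, false_fn returns a pair of tensors
    return tf.cond(x > 0, lambda: x, lambda: (x, x))


try:
    mismatched(tf.constant(1))
    graph_rejected = False
except Exception as err:  # type and message vary by TensorFlow version
    graph_rejected = True
    print(type(err).__name__, err)

# Eager execution calls only the selected branch, so nothing compares the two
eager_result = tf.cond(
    tf.constant(True),
    lambda: tf.constant(1),
    lambda: (tf.constant(1), tf.constant(1)),
)
print(eager_result)
```

Running branch-heavy code under tf.function at least once in tests is a cheap way to surface mismatches that eager execution hides.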

Prefer Python if When the Predicate Is Truly Python

If the condition is known in Python, keep it as a normal if. Using tf.cond for a plain Python boolean adds complexity without value.

```python
flag = True

if flag:
    value = 10
else:
    value = 20

print(value)
```

Use tf.cond only when the condition is a tensor that must remain inside TensorFlow execution.
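The same rule holds inside tf.function. In this sketch (scale and traces are hypothetical names), use_double is a plain Python bool, so the if is resolved while tracing and only one branch reaches each graph; TensorFlow traces a separate graph for each distinct Python value instead of building a tf.cond:

```python
import tensorflow as tf

traces = []  # records the Python value seen each time the function is traced


@tf.function
def scale(x, use_double):
    traces.append(use_double)
    if use_double:  # Python bool: decided at trace time, no tf.cond is built
        return x * 2
    return x


print(scale(tf.constant(3), True).numpy())   # 6
print(scale(tf.constant(3), True).numpy())   # 6, reuses the first trace
print(scale(tf.constant(3), False).numpy())  # 3, triggers a second trace
print(traces)                                # [True, False]
```

This is efficient when the flag takes few values, but each new Python value costs a retrace, which is another reason to reserve Python if for values that are genuinely fixed in Python.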

Practical Rule of Thumb

Use this simple mental model:

  • Python if for Python values
  • tf.cond for tensor predicates in graph-compatible code

That rule removes most ambiguity. If you break it, you usually end up fighting tracing errors, graph surprises, or branch mismatch exceptions.

Common Pitfalls

The main mistake is expecting tf.cond to behave exactly like a Python if inside traced code. Another frequent problem is putting Python side effects inside branch functions and then being surprised by tracing behavior. Developers also run into branch mismatches when the two functions return different structures or dtypes. Finally, many TensorFlow 2 programs can use ordinary Python control flow in eager mode, so reaching for tf.cond too early makes code harder to read than necessary.

Summary

  • tf.cond is for tensor-based conditionals, especially inside tf.function.
  • TensorFlow traces both branches when building the graph.
  • At execution time, only the selected branch runs, and its result is used.
  • Both branches must return compatible output structures.
  • Use a normal Python if when the predicate is an ordinary Python value.

