How to use tf.cond for batch processing

Introduction

tf.cond is for choosing between two branches based on one scalar boolean condition. That detail matters because many people try to use it for per-example branching across a batch, and that is usually the wrong tool.

What tf.cond Actually Does

tf.cond(pred, true_fn, false_fn) evaluates one of two callables depending on the value of pred. The important part is that pred must be a scalar boolean tensor, not a vector of one boolean per batch item.

A correct use looks like this:

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
use_double = tf.constant(True)

y = tf.cond(
    use_double,
    lambda: x * 2.0,
    lambda: x + 1.0,
)

print(y)
```

Here the entire tensor goes through one branch or the other.
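If the decision depends on the data rather than on a flag, one option is to collapse the per-element booleans into a single scalar first. The sketch below assumes the rule is "double the batch only when every element is positive"; tf.reduce_all does the collapsing, and tf.reduce_any would be the choice for an "if any element" rule.

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])

# x > 0 holds one boolean per element, so it cannot be passed as pred directly.
per_element = x > 0

# tf.reduce_all collapses those booleans into a single scalar boolean tensor.
all_positive = tf.reduce_all(per_element)

y = tf.cond(
    all_positive,
    lambda: x * 2.0,  # taken only if every element is positive
    lambda: x,        # otherwise the batch passes through unchanged
)
print(all_positive.shape, y)
```

The result is still a whole-batch decision: every element takes the same branch.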

Why It Fails for Per-Element Batch Logic

Suppose you have a batch and want to apply one operation to positive elements and another to the rest. A condition such as x > 0 does not belong in tf.cond: it produces one boolean per element, not a single scalar decision.

For per-element branching, use tf.where:

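```python
import tensorflow as tf

x = tf.constant([-2.0, -1.0, 0.5, 3.0])
# Positive elements are doubled; all others (including zero) have 1 subtracted.
y = tf.where(x > 0, x * 2.0, x - 1.0)

print(y)
```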

This applies the decision element-wise across the batch.
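One detail worth knowing: tf.where computes both candidate tensors in full and only then picks from them per element. The forward result is fine, but if one branch is undefined for some inputs, its gradient can still produce NaN. The sketch below uses tf.math.log as an assumed example of such a branch and shows a common workaround: sanitize the input with an inner tf.where so the log never sees invalid values.

```python
import tensorflow as tf

x = tf.constant([-2.0, 0.0, 1.0])

def naive_log(x):
    # log is evaluated for every element, including 0.0 (-inf) and -2.0 (nan).
    return tf.where(x > 0, tf.math.log(x), tf.zeros_like(x))

def safe_log(x):
    # Replace invalid inputs with 1.0 before taking the log, so every element
    # of the log branch is finite; the outer tf.where discards them anyway.
    safe_x = tf.where(x > 0, x, tf.ones_like(x))
    return tf.where(x > 0, tf.math.log(safe_x), tf.zeros_like(x))

def grad_of(fn, x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        loss = tf.reduce_sum(fn(x))
    return tape.gradient(loss, x)

naive_grad = grad_of(naive_log, x)  # NaN at the 0.0 position
safe_grad = grad_of(safe_log, x)    # finite everywhere
print(naive_log(x), naive_grad, safe_grad)
```

Both functions return the same forward values; only the gradients differ.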

Batch-Level Condition vs Element-Level Condition

This is the key distinction:

  • use tf.cond when one condition controls the whole batch or whole computation path
  • use tf.where when each element in the batch needs its own choice

Examples where tf.cond makes sense:

  • enable or disable data augmentation for the whole batch
  • switch between training and inference behavior in a custom graph block
  • choose one expensive preprocessing branch based on a global flag
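Part of what makes tf.cond a good fit for the expensive-branch case is that, at run time, only the selected branch executes. The sketch below makes that visible with two hypothetical counter variables, calls_expensive and calls_cheap, which each branch increments; preprocess_batch and its doubling step are likewise illustrative stand-ins.

```python
import tensorflow as tf

# Hypothetical counters used only to observe which branch actually runs.
calls_expensive = tf.Variable(0)
calls_cheap = tf.Variable(0)

def expensive_branch(batch):
    calls_expensive.assign_add(1)
    return batch * 2.0  # stand-in for a costly preprocessing step

def cheap_branch(batch):
    calls_cheap.assign_add(1)
    return batch

@tf.function
def preprocess_batch(batch, use_expensive):
    # use_expensive is one scalar flag for the whole batch.
    return tf.cond(
        use_expensive,
        lambda: expensive_branch(batch),
        lambda: cheap_branch(batch),
    )

out = preprocess_batch(tf.ones([4, 3]), tf.constant(False))
print(int(calls_expensive.numpy()), int(calls_cheap.numpy()))
```

While tracing, both Python lambdas are called once to build the graph, so a plain Python print inside a branch would fire for both; the TensorFlow ops themselves run only for the branch that is chosen.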

Examples where tf.where makes more sense:

  • clamp or transform individual values conditionally
  • apply one formula to positive elements and another to negative elements
  • mask batch elements independently
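For the masking case, tf.where broadcasts its condition against the two value tensors, so a per-example condition of shape [batch, 1] can select whole rows of a [batch, features] tensor. The sketch assumes a hypothetical keep vector marking which examples to retain.

```python
import tensorflow as tf

batch = tf.constant([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [7.0, 8.0, 9.0]])

# Hypothetical per-example mask: keep rows 0 and 2, zero out row 1.
keep = tf.constant([True, False, True])

# Reshape to [batch, 1] so the condition broadcasts across the feature axis.
masked = tf.where(keep[:, tf.newaxis], batch, tf.zeros_like(batch))
print(masked)
```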

Example: Batch-Level Augmentation Switch

A realistic tf.cond use in a training pipeline might be:

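```python
import tensorflow as tf

@tf.function
def preprocess(batch, training):
    # training must be a scalar boolean tensor, e.g. tf.constant(True).
    return tf.cond(
        training,
        lambda: batch + tf.random.normal(tf.shape(batch), stddev=0.1),  # noisy copy
        lambda: batch,  # unchanged
    )

batch = tf.ones([4, 3])
augmented = preprocess(batch, tf.constant(True))
clean = preprocess(batch, tf.constant(False))
```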

Here training is a scalar boolean controlling the whole batch path.

When tf.map_fn Enters the Picture

Sometimes the logic is per-example but too complex for a simple tf.where. In that case, another option is tf.map_fn, which applies a function across batch elements.

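```python
import tensorflow as tf

@tf.function
def transform_one(item):
    # item is one row of the batch; its mean is a scalar, so tf.cond is valid here.
    return tf.cond(
        tf.reduce_mean(item) > 0,
        lambda: item * 2.0,
        lambda: item - 1.0,
    )

batch = tf.constant([[1.0, 2.0], [-3.0, 1.0]])
result = tf.map_fn(transform_one, batch)
print(result)
```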

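This works because each call to transform_one receives one batch item and computes a scalar condition for it. It is more flexible than tf.where, but often slower than vectorized operations, because the function is applied once per element instead of as a single batched operation.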

Prefer Vectorization When Possible

TensorFlow performs best when the operation is expressed in vectorized form. If tf.where can solve the problem, it is usually preferable to a tf.map_fn loop with nested tf.cond.

That leads to faster graphs and simpler code.
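As an illustration, the per-row rule from the tf.map_fn example (double a row if its mean is positive, otherwise subtract 1) can be written without a loop. This is one possible vectorized rewrite, not the only one: tf.reduce_mean with keepdims=True gives a [batch, 1] condition that tf.where broadcasts across each row.

```python
import tensorflow as tf

batch = tf.constant([[1.0, 2.0], [-3.0, 1.0]])

# Loop version: one tf.cond per row.
def transform_one(item):
    return tf.cond(tf.reduce_mean(item) > 0, lambda: item * 2.0, lambda: item - 1.0)

looped = tf.map_fn(transform_one, batch)

# Vectorized version: one [batch, 1] condition, broadcast across each row.
row_is_positive = tf.reduce_mean(batch, axis=1, keepdims=True) > 0
vectorized = tf.where(row_is_positive, batch * 2.0, batch - 1.0)

print(looped, vectorized)
```

Both versions produce the same values; the vectorized one expresses the whole batch as a few tensor operations.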

Common Pitfalls

The biggest mistake is passing a batch-shaped boolean tensor into tf.cond and expecting element-wise branching. tf.cond does not work that way.

Another issue is reaching for tf.map_fn even when a vectorized tf.where solution exists, which adds unnecessary per-element overhead.

Developers also sometimes forget that both branches of tf.cond must return compatible structures and dtypes.
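A minimal sketch of branches that satisfy this rule: both return the same structure, a tuple of a float32 value and an int32 code. The function name summarize and the code values are hypothetical. It is wrapped in tf.function on purpose: in eager mode only the chosen branch runs, so a mismatch may go unnoticed until the code is traced.

```python
import tensorflow as tf

@tf.function
def summarize(x, use_sum):
    # Both branches return (float32 scalar, int32 scalar): same structure, same dtypes.
    # Returning tf.constant(0.0) in only one branch would mix float32 and int32
    # and be rejected when the function is traced.
    return tf.cond(
        use_sum,
        lambda: (tf.reduce_sum(x), tf.constant(0)),
        lambda: (tf.reduce_max(x), tf.constant(1)),
    )

x = tf.constant([1.0, 2.0, 3.0])
total, code = summarize(x, tf.constant(True))
largest, other_code = summarize(x, tf.constant(False))
```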

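Finally, be careful with Python if inside tf.function. AutoGraph converts an if whose condition is a tensor into graph control flow equivalent to tf.cond, but an if on a plain Python value is evaluated once, while the function is being traced. Be explicit when control flow matters.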
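A small sketch of the difference, using a hypothetical scale function: the first if tests a tensor, so AutoGraph turns it into graph control flow; the second tests a Python bool, so it is resolved during tracing and each distinct Python value gets its own trace.

```python
import tensorflow as tf

@tf.function
def scale(x, threshold, add_offset):
    # Tensor condition: AutoGraph converts this `if` into graph control flow.
    # Both branches assign y so it is defined after the statement.
    if tf.reduce_mean(x) > threshold:
        y = x * 2.0
    else:
        y = x - 1.0
    # Python bool: evaluated during tracing, so only one path ends up in the graph.
    if add_offset:
        y = y + 10.0
    return y

x = tf.constant([1.0, 2.0, 3.0])
print(scale(x, tf.constant(0.0), False))  # mean 2.0 > 0.0, so doubled
print(scale(x, tf.constant(5.0), True))   # mean 2.0 <= 5.0, so minus 1, then +10
```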

Summary

  • tf.cond is for one scalar boolean that controls an entire branch of computation.
  • It is not the right tool for element-wise conditional logic across a batch.
  • Use tf.where for vectorized per-element branching.
  • Use tf.map_fn only when each example needs more complex branch logic.
  • Prefer vectorized TensorFlow operations whenever they express the behavior you need.

All Rights Reserved.