How to use tf.cond for batch processing
Introduction
tf.cond is for choosing between two branches based on one scalar boolean condition. That detail matters because many people try to use it for per-example branching across a batch, and that is usually the wrong tool.
What tf.cond Actually Does
tf.cond(pred, true_fn, false_fn) evaluates one of two callables depending on the value of pred. The important part is that pred must be a scalar boolean tensor, not a vector of one boolean per batch item.
A correct use looks like this:
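The sketch below is one minimal illustration. The function name scale_or_shift and the sample values are assumptions chosen for the example; only tf.cond and tf.constant come from TensorFlow itself.

```python
import tensorflow as tf

def scale_or_shift(x, use_scale):
    # use_scale is expected to be a single scalar boolean for the whole tensor.
    return tf.cond(
        use_scale,
        true_fn=lambda: x * 2.0,   # runs when use_scale is True
        false_fn=lambda: x + 1.0,  # runs when use_scale is False
    )

x = tf.constant([1.0, 2.0, 3.0])
doubled = scale_or_shift(x, tf.constant(True))   # every element doubled
shifted = scale_or_shift(x, tf.constant(False))  # every element shifted by 1
```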
Here the entire tensor goes through one branch or the other.
Why It Fails for Per-Element Batch Logic
Suppose you have a batch and want to do one operation for elements above zero and another for elements below zero. A vector condition such as x > 0 does not belong in tf.cond because that is a batch of booleans, not one scalar decision.
For per-element branching, use tf.where:
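A small sketch of that pattern follows. The input values and the two formulas (doubling and halving) are illustrative assumptions, not something the original text prescribes.

```python
import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 1.5, 3.0])

# x > 0 is a boolean tensor with one entry per element.
# tf.where picks from the second argument where it is True,
# and from the third argument where it is False.
y = tf.where(x > 0, x * 2.0, x * 0.5)
# y is [-1.0, -0.25, 0.0, 3.0, 6.0]
```

Note that both candidate tensors (x * 2.0 and x * 0.5) are computed in full before tf.where selects between them.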
This applies the decision element-wise across the batch.
Batch-Level Condition vs Element-Level Condition
This is the key distinction:
- use tf.cond when one condition controls the whole batch or whole computation path
- use tf.where when each element in the batch needs its own choice
Examples where tf.cond makes sense:
- enable or disable data augmentation for the whole batch
- switch between training and inference behavior in a custom graph block
- choose one expensive preprocessing branch based on a global flag
Examples where tf.where makes more sense:
- clamp or transform individual values conditionally
- apply one formula to positive elements and another to negative elements
- mask batch elements independently
Example: Batch-Level Augmentation Switch
A realistic tf.cond use in a training pipeline might be:
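The following is a hedged sketch rather than a production pipeline. The helper name maybe_augment, the image shape, and the choice of tf.image.random_flip_left_right as the augmentation are all assumptions made for illustration.

```python
import tensorflow as tf

def maybe_augment(images, training):
    # training is a scalar boolean tensor; it selects the path for the whole batch.
    return tf.cond(
        training,
        true_fn=lambda: tf.image.random_flip_left_right(images),
        false_fn=lambda: tf.identity(images),  # inference: leave the batch unchanged
    )

# Hypothetical batch of 4 RGB images, 8x8 pixels each.
images = tf.random.uniform([4, 8, 8, 3])
train_out = maybe_augment(images, tf.constant(True))
eval_out = maybe_augment(images, tf.constant(False))
```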
Here training is a scalar boolean controlling the whole batch path.
When tf.map_fn Enters the Picture
Sometimes the logic is per-example but too complex for a simple tf.where. In that case, another option is tf.map_fn, which applies a function across batch elements.
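One way this could look is sketched below. The body of transform_one (normalize a row when its sum is positive, otherwise zero it out) and the sample batch are assumptions picked to keep the example small.

```python
import tensorflow as tf

def transform_one(row):
    # row is a single example; its sum is a scalar, so the predicate is scalar too.
    total = tf.reduce_sum(row)
    return tf.cond(
        total > 0,
        lambda: row / total,          # normalize rows with a positive sum
        lambda: tf.zeros_like(row),   # zero out all other rows
    )

batch = tf.constant([[1.0, 3.0], [-1.0, -2.0], [2.0, 2.0]])

# tf.map_fn unstacks batch along axis 0 and calls transform_one per row.
result = tf.map_fn(transform_one, batch)
# result is [[0.25, 0.75], [0.0, 0.0], [0.5, 0.5]]
```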
This works because each call to transform_one uses a scalar condition for one batch item. It is more flexible, but also often slower than vectorized operations.
Prefer Vectorization When Possible
TensorFlow performs best when the operation is expressed in vectorized form. If tf.where can solve the problem, it is usually preferable to a tf.map_fn loop with nested tf.cond.
That leads to faster graphs and simpler code.
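To make the comparison concrete, here is a sketch that expresses the same per-row rule twice: once as a tf.map_fn loop with a nested tf.cond, and once with tf.where. The rule itself (halve any row whose maximum exceeds 2.0) and the name halve_if_large are hypothetical.

```python
import tensorflow as tf

batch = tf.constant([[1.0, 3.0], [-1.0, -2.0], [2.0, 2.0]])

# Loop version: one scalar decision per row.
def halve_if_large(row):
    return tf.cond(tf.reduce_max(row) > 2.0, lambda: row * 0.5, lambda: row)

looped = tf.map_fn(halve_if_large, batch)

# Vectorized version: compute every row's decision at once.
# keepdims=True gives shape [3, 1], which broadcasts against the [3, 2] batch.
row_max = tf.reduce_max(batch, axis=1, keepdims=True)
vectorized = tf.where(row_max > 2.0, batch * 0.5, batch)
```

Both versions produce the same tensor here; the vectorized one avoids the per-row loop.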
Common Pitfalls
The biggest mistake is passing a batch-shaped boolean tensor into tf.cond and expecting element-wise branching. tf.cond does not work that way.
Another issue is using tf.map_fn immediately even when a vectorized tf.where solution exists. That adds unnecessary overhead.
Developers also sometimes forget that both branches of tf.cond must return compatible structures and dtypes.
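As an illustration of keeping the branches compatible, the sketch below casts one branch so that both return a float32 scalar. The function name mean_or_count is an assumption made for the example.

```python
import tensorflow as tf

@tf.function
def mean_or_count(x, use_mean):
    return tf.cond(
        use_mean,
        lambda: tf.reduce_mean(x),  # float32 scalar
        # tf.size returns int32, so cast it to match the other branch.
        lambda: tf.cast(tf.size(x), tf.float32),
    )

x = tf.constant([1.0, 2.0, 3.0, 6.0])
mean_value = mean_or_count(x, tf.constant(True))    # 3.0
count_value = mean_or_count(x, tf.constant(False))  # 4.0
```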
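Finally, do not use a Python if statement inside code that must become a TensorFlow graph unless you understand how AutoGraph will convert it: a condition on a tensor becomes a graph conditional, while a condition on a plain Python value is evaluated once at trace time. Be explicit when control flow matters.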
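The sketch below contrasts the two styles under tf.function. The function names implicit_branch and explicit_branch are hypothetical, and the example assumes AutoGraph is enabled, which is the default for tf.function.

```python
import tensorflow as tf

@tf.function
def implicit_branch(x):
    # The condition is a tensor, so AutoGraph converts this Python `if`
    # into a graph conditional when the function is traced.
    if tf.reduce_sum(x) > 0:
        return x * 2.0
    else:
        return x - 1.0

@tf.function
def explicit_branch(x):
    # Same logic, written directly with tf.cond so the intent is visible.
    return tf.cond(tf.reduce_sum(x) > 0, lambda: x * 2.0, lambda: x - 1.0)

positive = tf.constant([1.0, 2.0])
negative = tf.constant([-1.0, -2.0])
```

Calling either function with positive or negative should give matching results; the explicit form simply does not rely on the conversion step.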
Summary
- tf.cond is for one scalar boolean that controls an entire branch of computation.
- It is not the right tool for element-wise conditional logic across a batch.
- Use tf.where for vectorized per-element branching.
- Use tf.map_fn only when each example needs more complex branch logic.
- Prefer vectorized TensorFlow operations whenever they express the behavior you need.

