How to pass parameters to functions inside tf.cond in TensorFlow?
Introduction
tf.cond expects two zero-argument callables, not two function calls with explicit arguments. So the usual way to "pass parameters" is to capture the tensors from the surrounding scope with a lambda, nested function, or helper factory.
The Core Shape of tf.cond
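The signature is conceptually tf.cond(pred, true_fn, false_fn), where pred is the condition and the other two arguments are the branches.
Notice that true_fn and false_fn are functions, not already-executed results. That means passing my_func(x) directly as a branch argument, as in tf.cond(pred, my_func(x), ...), is wrong, because my_func(x) runs immediately in Python instead of being passed as a callable.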
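To make the difference concrete, here is a minimal sketch that contrasts calling a branch function with passing a callable. The names double, negate, x, and pred are hypothetical and chosen only for illustration.

```python
import tensorflow as tf


def double(x):
    return x * 2.0


def negate(x):
    return -x


x = tf.constant(3.0)
pred = tf.constant(True)

# Wrong (left commented out): double(x) and negate(x) would run right here in
# Python, so tf.cond would receive tensors instead of callables.
# result = tf.cond(pred, double(x), negate(x))

# Right: wrap each call so tf.cond receives zero-argument callables.
result = tf.cond(pred, lambda: double(x), lambda: negate(x))
```

With pred set to True, only the first lambda contributes to the result.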
The Normal Solution: Lambdas
The common pattern is to close over the needed tensors.
The branch functions take no explicit arguments, but they still have access to x and scale from the outer scope.
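A small sketch of this closure pattern, assuming illustrative values for x and scale and a predicate based on the sum of x (the predicate is an assumption, not something the original specifies):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
scale = tf.constant(10.0)

# Scalar boolean tensor used as the predicate.
pred = tf.reduce_sum(x) > 5.0

result = tf.cond(
    pred,
    lambda: x * scale,  # closes over both x and scale
    lambda: x,          # closes over x only
)
```

Neither lambda declares a parameter; both read x (and scale) from the enclosing scope at the time the branch is evaluated.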
Passing Multiple Tensors
The same idea works when the branch logic depends on several values.
This is still parameterless from TensorFlow's point of view, but operationally the branch has access to everything it needs.
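One possible shape for this, using nested zero-argument functions instead of lambdas. The tensors features, weights, bias, and use_bias are hypothetical names introduced for the example.

```python
import tensorflow as tf

features = tf.constant([1.0, 2.0])
weights = tf.constant([0.5, 0.25])
bias = tf.constant(1.0)
use_bias = tf.constant(True)


def with_bias():
    # Uses three captured tensors: features, weights, and bias.
    return tf.reduce_sum(features * weights) + bias


def without_bias():
    # Uses two captured tensors: features and weights.
    return tf.reduce_sum(features * weights)


# The functions themselves are passed, not their results.
result = tf.cond(use_bias, with_bias, without_bias)
```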
Helper Functions for Readability
When the branch logic gets too large for a lambda, create helper factories that return zero-argument functions.
This keeps complex branch logic readable without violating the tf.cond API.
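The factory idea could look roughly like the following sketch. The helpers make_clip_branch and make_identity_branch are hypothetical; each takes explicit arguments and returns a zero-argument function that tf.cond can accept.

```python
import tensorflow as tf


def make_clip_branch(x, limit):
    """Return a zero-argument function that clips x to [-limit, limit]."""
    def branch():
        return tf.clip_by_value(x, -limit, limit)
    return branch


def make_identity_branch(x):
    """Return a zero-argument function that passes x through unchanged."""
    def branch():
        return tf.identity(x)
    return branch


x = tf.constant([-5.0, 0.5, 7.0])
limit = tf.constant(1.0)
should_clip = tf.reduce_max(tf.abs(x)) > limit

# Calling the factories here is fine: they return callables, not results.
result = tf.cond(
    should_clip,
    make_clip_branch(x, limit),
    make_identity_branch(x),
)
```

Calling a factory at the tf.cond call site is not the same mistake as calling the branch itself, because the factory only builds the function and defers the actual computation.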
tf.function and Graph Context
tf.cond matters most when you are working in graph-style TensorFlow, especially inside @tf.function.
In eager code, ordinary Python if statements are often enough. tf.cond becomes important when you want TensorFlow-traceable control flow with tensor predicates.
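As an illustration of tf.cond inside @tf.function, the sketch below chooses between two operations based on a tensor predicate. The function name relu_or_square and the threshold argument are assumptions made for the example.

```python
import tensorflow as tf


@tf.function
def relu_or_square(x, threshold):
    # The predicate depends on tensor values, so the branch choice is made
    # when the traced function runs, not while Python builds the graph.
    return tf.cond(
        tf.reduce_mean(x) > threshold,
        lambda: tf.nn.relu(x),
        lambda: tf.square(x),
    )


# Mean is 1.5, above the threshold, so the relu branch is taken.
a = relu_or_square(tf.constant([-1.0, 4.0]), tf.constant(1.0))

# Mean is -0.5, below the threshold, so the square branch is taken.
b = relu_or_square(tf.constant([-1.0, 0.0]), tf.constant(1.0))
```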
Branch Outputs Must Match
Passing arguments is only part of the story. Both branches also need to return compatible structures and dtypes.
Bad example:
- true branch returns a scalar
- false branch returns a vector
That mismatch causes tracing or execution problems. Keep both branches structurally aligned.
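A hedged sketch of an aligned pair of branches: both return a float32 tensor with the same shape as the input. The function normalize_or_zero is a hypothetical example, not part of TensorFlow.

```python
import tensorflow as tf


@tf.function
def normalize_or_zero(x):
    total = tf.reduce_sum(x)
    return tf.cond(
        total > 0.0,
        lambda: x / total,         # float32, same shape as x
        lambda: tf.zeros_like(x),  # float32, same shape as x
    )


normalized = normalize_or_zero(tf.constant([1.0, 3.0]))
zeros = normalize_or_zero(tf.constant([0.0, 0.0]))
```

Using tf.zeros_like for the fallback is one way to guarantee that the second branch mirrors the shape and dtype of the first.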
What Not to Do
Do not use Python side effects as if they were branch outputs. For example, mutating a Python list inside one branch is not a good substitute for returning a tensor result.
Do not assume tf.cond behaves exactly like Python if. The predicate is a tensor expression, and the branch functions are part of TensorFlow control flow, especially inside @tf.function.
If you need debugging, prefer tf.print or explicit returned values over Python-only side effects.
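The sketch below shows why a Python side effect is a poor stand-in for a branch output. It assumes the function runs under @tf.function; the list trace_log and the function pick are hypothetical names used only for this demonstration.

```python
import tensorflow as tf

# Plain Python list, used only to show what happens while tracing.
trace_log = []


@tf.function
def pick(flag, x):
    def true_branch():
        trace_log.append("true")          # Python side effect: runs during tracing
        tf.print("taking true branch")    # graph op: runs when the branch executes
        return x + 1.0

    def false_branch():
        trace_log.append("false")         # also runs during tracing
        tf.print("taking false branch")
        return x - 1.0

    return tf.cond(flag, true_branch, false_branch)


result = pick(tf.constant(True), tf.constant(1.0))

# Both branches were traced, so the list records both names even though only
# the true branch contributed to the result.
branches_seen = set(trace_log)
```

The returned tensor reflects the branch that actually ran, while the Python list only reflects which functions were traced.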
Common Pitfalls
Calling the branch function immediately instead of passing a callable is the most common mistake.
Returning incompatible tensor shapes or structures from the two branches causes confusing tracing failures.
Mixing Python values and tensors carelessly can make the control flow behave differently from what you expect.
Using tf.cond where a plain Python if would be clearer in eager-only code makes the code needlessly hard to read.
Summary
- tf.cond expects zero-argument callables for the true and false branches.
- Pass needed tensors by capturing them in lambdas or nested functions.
- Use helper factories when branch logic is too large for one lambda.
- Keep both branch outputs compatible in structure and dtype.
- Inside graph-style TensorFlow, think of tf.cond as TensorFlow control flow, not as ordinary Python if.

