How to pass parameters to functions inside tf.cond in TensorFlow?

Introduction

tf.cond expects two zero-argument callables, not the results of calling functions with explicit arguments. So the usual way to "pass parameters" is to capture the tensors from the surrounding scope with a lambda, a nested function, or a helper factory.

The Core Shape of tf.cond

The signature is conceptually:

```python
tf.cond(pred, true_fn, false_fn)
```

Notice that true_fn and false_fn are functions, not already-executed results. That means this is wrong:

```python
tf.cond(pred, true_fn=my_func(x), false_fn=other_func(x))
```

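because my_func(x) and other_func(x) both run immediately in Python, before tf.cond ever looks at pred, and tf.cond receives their results instead of callables. TensorFlow 2 rejects this with a TypeError.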
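To make the difference concrete, here is a small sketch with two hypothetical helpers, my_func and other_func, that each take an argument. Calling them inline fails, while wrapping each call in a zero-argument lambda works. It assumes TensorFlow 2.x running eagerly; the exact error message may differ between versions.

```python
import tensorflow as tf

# Hypothetical helpers that need an argument.
def my_func(x):
    return x + 1

def other_func(x):
    return x - 1

x = tf.constant(3)
pred = x > 0

# Wrong: both helpers run right here, so tf.cond receives two tensors, not callables.
try:
    tf.cond(pred, true_fn=my_func(x), false_fn=other_func(x))
    wrong_call_failed = False
except TypeError as err:
    wrong_call_failed = True
    print("TypeError:", err)

# Right: each lambda defers the call until tf.cond picks a branch.
right = tf.cond(pred, true_fn=lambda: my_func(x), false_fn=lambda: other_func(x))
print(right.numpy())  # 4
```

Since pred is True here, only my_func runs in the lambda version.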

The Normal Solution: Lambdas

The common pattern is to close over the needed tensors.

```python
import tensorflow as tf

x = tf.constant(5)
scale = tf.constant(2)

result = tf.cond(
    x > 0,
    true_fn=lambda: x * scale,
    false_fn=lambda: x - scale,
)

print(result.numpy())
```

The branch functions take no explicit arguments, but they still have access to x and scale from the outer scope.

Passing Multiple Tensors

The same idea works when the branch logic depends on several values.

```python
import tensorflow as tf

a = tf.constant([1.0, 2.0, 3.0])
b = tf.constant([0.5, 0.5, 0.5])
use_add = tf.constant(True)

out = tf.cond(
    use_add,
    true_fn=lambda: tf.add(a, b),
    false_fn=lambda: tf.subtract(a, b),
)

print(out.numpy())
```

This is still parameterless from TensorFlow's point of view, but operationally the branch has access to everything it needs.
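Captured tensors flow in through the closure, and results flow out through the return value. A branch may return a tuple instead of a single tensor, as long as the other branch returns the same structure with the same dtypes. A minimal sketch using the same a and b values (the extra reduce_sum output is purely illustrative):

```python
import tensorflow as tf

a = tf.constant([1.0, 2.0, 3.0])
b = tf.constant([0.5, 0.5, 0.5])
use_add = tf.constant(False)

# Both branches return a (vector, scalar) pair of float32 tensors.
combined, total = tf.cond(
    use_add,
    true_fn=lambda: (a + b, tf.reduce_sum(a + b)),
    false_fn=lambda: (a - b, tf.reduce_sum(a - b)),
)

print(combined.numpy(), total.numpy())
```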

Helper Functions for Readability

When the branch logic gets too large for a lambda, create helper factories that return zero-argument functions.

```python
import tensorflow as tf

def make_true_fn(x, w, bias):
    def _true():
        return tf.nn.relu(tf.matmul(x, w) + bias)
    return _true

def make_false_fn(x, w):
    def _false():
        return tf.matmul(x, w)
    return _false

x = tf.ones((2, 3))
w = tf.ones((3, 4))
bias = tf.zeros((4,))
pred = tf.constant(False)

y = tf.cond(pred, make_true_fn(x, w, bias), make_false_fn(x, w))
print(y.shape)
```

This keeps complex branch logic readable without violating the tf.cond API.
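The standard library's functools.partial is a shortcut for the same factory pattern. It binds arguments to an ordinary function and returns a callable that needs no further arguments, which is exactly what tf.cond expects. A sketch assuming the same shapes as above; dense_relu and dense_linear are hypothetical names:

```python
import functools

import tensorflow as tf

# Hypothetical branch functions that take their inputs as ordinary arguments.
def dense_relu(x, w, bias):
    return tf.nn.relu(tf.matmul(x, w) + bias)

def dense_linear(x, w):
    return tf.matmul(x, w)

x = tf.ones((2, 3))
w = tf.ones((3, 4))
bias = tf.zeros((4,))
pred = tf.constant(True)

# partial(...) binds the arguments now; tf.cond later calls it with no arguments.
y = tf.cond(
    pred,
    functools.partial(dense_relu, x, w, bias),
    functools.partial(dense_linear, x, w),
)
print(y.shape)  # (2, 4)
```

One design difference: partial stores the argument objects when it is created, whereas a lambda looks the names up when tf.cond calls it. If the captured variables are not reassigned in between, both behave the same.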

tf.function and Graph Context

tf.cond matters most when you are working in graph-style TensorFlow, especially inside @tf.function.

```python
import tensorflow as tf

@tf.function
def compute(x):
    return tf.cond(
        tf.reduce_mean(x) > 0,
        true_fn=lambda: x * 2.0,
        false_fn=lambda: x * 0.5,
    )

print(compute(tf.constant([1.0, 2.0])).numpy())
```

In eager code, ordinary Python if statements are often enough. tf.cond becomes important when you want TensorFlow-traceable control flow with tensor predicates.
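Inside @tf.function, AutoGraph also rewrites a Python if whose condition is a tensor into graph control flow, so the two spellings below should give the same result. This sketch assumes AutoGraph is enabled, which is the tf.function default; with_cond and with_if are illustrative names.

```python
import tensorflow as tf

@tf.function
def with_cond(x):
    # Explicit graph conditional with zero-argument branches.
    return tf.cond(tf.reduce_mean(x) > 0, lambda: x * 2.0, lambda: x * 0.5)

@tf.function
def with_if(x):
    # AutoGraph converts this tensor-dependent if into a graph conditional while tracing.
    if tf.reduce_mean(x) > 0:
        y = x * 2.0
    else:
        y = x * 0.5
    return y

x = tf.constant([1.0, -4.0])
print(with_cond(x).numpy(), with_if(x).numpy())
```

The explicit tf.cond keeps the graph construction visible. The AutoGraph form reads like ordinary Python but still requires y to be defined in both branches.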

Branch Outputs Must Match

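Passing arguments is only part of the story. Both branches also need to return the same nested structure (for example, both a single tensor, or both a 2-tuple) with matching dtypes.

Bad examples:

  • the true branch returns a float32 tensor while the false branch returns an int32 tensor
  • the true branch returns a single tensor while the false branch returns a tuple

These mismatches raise an error when TensorFlow traces both branches, for example inside @tf.function. Shapes are treated more leniently: if one branch returns a scalar and the other a vector, the cond is usually still built, but the static shape of its output becomes less specific, which can break downstream code that expects a fixed shape. Keep both branches structurally aligned.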
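In eager mode, tf.cond runs only the branch that the predicate selects, so a dtype mismatch can go unnoticed until the same code is traced. The following sketch shows that failure mode. It assumes TensorFlow 2.x; scale_or_keep is a hypothetical name, and the error is caught as TypeError or ValueError to be safe across versions.

```python
import tensorflow as tf

def scale_or_keep(x):
    return tf.cond(
        x > 0,
        true_fn=lambda: tf.cast(x, tf.float32) * 2.0,  # float32 result
        false_fn=lambda: x,                            # int32 result: dtype mismatch
    )

x = tf.constant(3)

# Eager: only the true branch runs, so the mismatch is not detected.
eager_result = scale_or_keep(x)
print(eager_result.numpy())  # 6.0

# Traced: tf.function builds both branches and compares their outputs.
try:
    tf.function(scale_or_keep)(x)
    traced_ok = True
except (TypeError, ValueError) as err:
    traced_ok = False
    print(type(err).__name__)

# Fix: make both branches return float32.
def scale_or_keep_fixed(x):
    return tf.cond(
        x > 0,
        true_fn=lambda: tf.cast(x, tf.float32) * 2.0,
        false_fn=lambda: tf.cast(x, tf.float32),
    )

fixed_result = tf.function(scale_or_keep_fixed)(x)
print(fixed_result.numpy())  # 6.0
```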

What Not to Do

Do not use Python side effects as if they were branch outputs. For example, mutating a Python list inside one branch is not a good substitute for returning a tensor result.

Do not assume tf.cond behaves exactly like Python if. The predicate is a tensor expression, and inside @tf.function both branch functions are traced into the graph, even though only one of them runs on any given call.

When debugging, prefer tf.print or explicitly returned values over Python-only side effects.
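A quick way to see why Python side effects are a poor stand-in for branch outputs is to compare a Python list append with tf.print inside a traced function. In this sketch (trace_log and step are illustrative names), the appends happen only while tf.function traces, and they happen for both branches. tf.print runs at execution time and only for the branch actually taken.

```python
import tensorflow as tf

trace_log = []  # plain Python list, mutated only during tracing

@tf.function
def step(x):
    def true_fn():
        trace_log.append("true branch traced")  # Python side effect
        tf.print("true branch ran")               # graph op, runs on each call
        return x + 1

    def false_fn():
        trace_log.append("false branch traced")
        tf.print("false branch ran")
        return x - 1

    return tf.cond(x > 0, true_fn, false_fn)

first = step(tf.constant(1))
second = step(tf.constant(2))  # same input signature, so no retracing

# Each entry appears once, although only the true branch ever executed.
print(trace_log)
```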

Common Pitfalls

Calling the branch function immediately instead of passing a callable is the most common mistake.

Returning incompatible tensor shapes or structures from the two branches causes confusing tracing failures.

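Mixing Python values and tensors carelessly can make the control flow behave differently from what you expect: Python values captured from the enclosing scope are frozen when a @tf.function is traced, whereas tensor arguments are read on every call.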
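Here is a small sketch of a captured Python value being frozen at trace time, alongside a version that passes the value as a tensor argument so it is read on each call. The names factor, apply, and apply_with are illustrative.

```python
import tensorflow as tf

factor = 2  # plain Python int captured by the branch lambda

@tf.function
def apply(x):
    return tf.cond(x > 0, lambda: x * factor, lambda: x)

before = apply(tf.constant(3))
factor = 10  # rebinding the global does not retrace apply
after = apply(tf.constant(3))
print(before.numpy(), after.numpy())  # 6 6

# Passing the value as a tensor argument keeps it a runtime input instead.
@tf.function
def apply_with(x, factor):
    return tf.cond(x > 0, lambda: x * factor, lambda: x)

with_arg = apply_with(tf.constant(3), tf.constant(10))
print(with_arg.numpy())  # 30
```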

Using tf.cond where a plain Python if would be clearer in eager-only code makes the code harder to read for no benefit.

Summary

  • tf.cond expects zero-argument callables for the true and false branches.
  • Pass needed tensors by capturing them in lambdas or nested functions.
  • Use helper factories when branch logic is too large for one lambda.
  • Keep both branch outputs compatible in structure and dtype.
  • Inside graph-style TensorFlow, think of tf.cond as TensorFlow control flow, not as ordinary Python if.
