In TensorFlow, how do you use tf.gather on the last dimension?

Introduction

tf.gather is one of TensorFlow's most useful indexing operations, but many examples only show its default behavior on axis 0. In real models, you often need to pick values from the last dimension, such as selecting logits, channels, or embedding components. The key is to set axis=-1 and understand how the output shape changes.

Gathering From the Last Axis

By default, tf.gather(params, indices) reads from the first axis. If the values you want live in the final dimension, pass axis=-1. Negative axes work the same way they do in NumPy: -1 means the last axis, -2 means the second-to-last axis, and so on.

Here is a simple example with a rank-3 tensor:

python
import tensorflow as tf

# Rank-3 tensor with shape (2, 2, 4); the last axis holds 4 values.
tensor = tf.constant(
    [
        [[10, 11, 12, 13], [20, 21, 22, 23]],
        [[30, 31, 32, 33], [40, 41, 42, 43]],
    ],
    dtype=tf.int32,
)

# Pick positions 1 and 3 along the last axis of every innermost row.
result = tf.gather(tensor, indices=[1, 3], axis=-1)

print(result)
print(result.shape)

Output:

python
tf.Tensor(
[[[11 13]
  [21 23]]

 [[31 33]
  [41 43]]], shape=(2, 2, 2), dtype=int32)
(2, 2, 2)

TensorFlow replaced the last dimension of size 4 with the length of indices, which is 2. That shape rule is the main thing to remember: the gathered axis is replaced by the shape of the index tensor. With the default batch_dims=0, the output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:].
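The shape rule is easier to see when indices has more than one dimension. The sketch below uses a hypothetical all-zeros tensor and a made-up index tensor of shape (3, 2); only the shapes matter here, not the values.

```python
import tensorflow as tf

# Hypothetical input: shape (2, 2, 4), so the last axis has size 4.
params = tf.zeros((2, 2, 4))

# Illustrative index tensor with shape (3, 2); every entry is in [0, 4).
indices = tf.constant([[0, 1], [2, 3], [3, 0]])

# The last axis (size 4) is replaced by the whole shape of `indices`.
result = tf.gather(params, indices, axis=-1)

print(result.shape)  # (2, 2, 3, 2)
```

The leading dimensions (2, 2) are untouched, and the single gathered axis expands into the two index dimensions (3, 2).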

This is useful when you want a subset of channels from an image tensor or a subset of class scores from a model output.
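As a rough illustration of the channel case, the following sketch assumes a batch of images in NHWC layout (batch, height, width, channels) with random values standing in for real data. The names images and red_blue are made up for this example.

```python
import tensorflow as tf

# Assumed layout: (batch, height, width, channels) with 3 channels.
images = tf.random.uniform((8, 32, 32, 3))

# Keep channels 0 and 2 and drop channel 1.
red_blue = tf.gather(images, [0, 2], axis=-1)

print(red_blue.shape)  # (8, 32, 32, 2)
```

The same call works for class scores: replace images with a (batch, num_classes) tensor and pass the class ids you want to keep.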

Selecting One Value Per Row With batch_dims

A common source of confusion is per-example selection. Suppose each row has its own index, and you want one value from the last axis for each row. If you use plain tf.gather, TensorFlow treats indices as a shared list and applies every index to every row, rather than pairing one index with each batch item.

Use batch_dims=1 when the leading dimension is a batch dimension shared by both params and indices:

python
import tensorflow as tf

# One row of class scores per example: shape (3, 3).
logits = tf.constant(
    [
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ],
    dtype=tf.float32,
)

# One class index per example: shape (3,).
class_ids = tf.constant([1, 0, 2], dtype=tf.int32)

# For this rank-2 tensor, axis=1 is the last axis.
picked = tf.gather(logits, class_ids, axis=1, batch_dims=1)

print(picked)

Output:

python
tf.Tensor([0.7 0.6 0.6], shape=(3,), dtype=float32)

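Without batch_dims=1, the same call returns a (3, 3) tensor in which every row has been reordered by all three indices, which is not a per-row selection. With batch_dims=1, TensorFlow pairs each row of logits with the matching entry of class_ids and gathers along the last axis inside each batch item.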
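To make the difference concrete, this sketch runs the same inputs with and without batch_dims. The variable names shared and per_row are illustrative only.

```python
import tensorflow as tf

logits = tf.constant(
    [
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ],
    dtype=tf.float32,
)
class_ids = tf.constant([1, 0, 2], dtype=tf.int32)

# No batch_dims: columns 1, 0, 2 are taken from every row -> shape (3, 3).
shared = tf.gather(logits, class_ids, axis=1)

# batch_dims=1: row i only uses class_ids[i] -> shape (3,).
per_row = tf.gather(logits, class_ids, axis=1, batch_dims=1)

print(shared.shape)   # (3, 3)
print(per_row.shape)  # (3,)

# The per-row result is the diagonal of the shared result.
print(bool(tf.reduce_all(tf.equal(tf.linalg.diag_part(shared), per_row))))  # True
```

The diagonal relationship holds because entry (i, j) of shared is logits[i, class_ids[j]], and the per-row result keeps only the entries where i equals j.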

When To Use tf.gather_nd Instead

If you need indexing across multiple axes at once, tf.gather is no longer the best tool. For example, selecting a different row and column pair for each item is better handled with tf.gather_nd.

python
import tensorflow as tf

matrix = tf.constant(
    [
        [5, 6, 7],
        [8, 9, 10],
        [11, 12, 13],
    ],
    dtype=tf.int32,
)

# Each inner pair is a full (row, column) coordinate.
indices = tf.constant([[0, 2], [2, 1]], dtype=tf.int32)

result = tf.gather_nd(matrix, indices)
print(result)

Output:

python
tf.Tensor([ 7 12], shape=(2,), dtype=int32)

Use tf.gather when one axis is being indexed. Use tf.gather_nd when each selection needs coordinates across several axes.
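The two operations overlap for per-row selection. As a sketch, the per-example pick from the earlier section can also be written with tf.gather_nd by building explicit (row, column) coordinates. The helper names rows and coords are made up for this example.

```python
import tensorflow as tf

logits = tf.constant(
    [
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ],
    dtype=tf.float32,
)
class_ids = tf.constant([1, 0, 2], dtype=tf.int32)

# Build full coordinates: [[0, 1], [1, 0], [2, 2]].
rows = tf.range(tf.shape(logits)[0])
coords = tf.stack([rows, class_ids], axis=1)

via_nd = tf.gather_nd(logits, coords)
via_gather = tf.gather(logits, class_ids, axis=1, batch_dims=1)

print(bool(tf.reduce_all(tf.equal(via_nd, via_gather))))  # True
```

When only the last axis varies per example, the batch_dims form avoids building the coordinate tensor. tf.gather_nd earns its keep once more than one axis needs a per-item index.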

Common Pitfalls

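The most common mistake is forgetting to set axis=-1. If you omit it, TensorFlow gathers from axis 0, which usually returns a tensor with the wrong shape. If an index is larger than the size of axis 0, the call raises an out-of-range error on CPU; on GPU, the out-of-range positions are filled with zeros instead.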
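The following sketch shows how the default axis can slip through without an error. The tensor scores is a hypothetical (2, 3, 4) input, and the indices are deliberately valid on both axis 0 and the last axis.

```python
import tensorflow as tf

# Hypothetical input with shape (2, 3, 4), values 0..23.
scores = tf.reshape(tf.range(24), (2, 3, 4))
indices = [0, 1]

# Default axis=0: selects whole (3, 4) slices.
default_axis = tf.gather(scores, indices)

# axis=-1: selects two entries from each innermost row.
last_axis = tf.gather(scores, indices, axis=-1)

print(default_axis.shape)  # (2, 3, 4)
print(last_axis.shape)     # (2, 3, 2)
```

Both calls succeed, so only the shape reveals that the first one gathered along the wrong axis.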

Another issue is mixing up shared indices and per-batch indices. If every example in a batch needs a different index, add the correct batch_dims value instead of expecting tf.gather to infer it.

Index type also matters. TensorFlow expects integer indices, typically int32 or int64. If indices are floating-point values, the operation will fail.
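If indices arrive as floats, for example from a preprocessing step, one option is to cast them before gathering. This is a minimal sketch with made-up values; note that tf.cast truncates toward zero, so it is only safe when the floats already hold whole numbers.

```python
import tensorflow as tf

values = tf.constant([[10, 20, 30], [40, 50, 60]], dtype=tf.int32)

# Hypothetical float indices that hold whole numbers.
float_ids = tf.constant([2.0, 0.0], dtype=tf.float32)

# Convert to an integer dtype before calling tf.gather.
int_ids = tf.cast(float_ids, tf.int32)

picked = tf.gather(values, int_ids, axis=-1)
print(picked.numpy().tolist())  # [[30, 10], [60, 40]]
```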

Finally, shape surprises are normal when you first use gather. Print result.shape during development so you can confirm that the indexed axis has been replaced in the way you expect.
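Beyond printing, shapes can be checked programmatically. The sketch below uses tf.debugging.assert_shapes with symbolic dimension names; the labels "batch" and "classes" are arbitrary strings chosen for this example.

```python
import tensorflow as tf

logits = tf.constant(
    [
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ],
    dtype=tf.float32,
)
class_ids = tf.constant([1, 0, 2], dtype=tf.int32)
picked = tf.gather(logits, class_ids, axis=1, batch_dims=1)

# Raises an error if the tensors disagree on the "batch" dimension
# or if `picked` is not rank 1.
tf.debugging.assert_shapes(
    [
        (logits, ("batch", "classes")),
        (class_ids, ("batch",)),
        (picked, ("batch",)),
    ]
)

print(picked.shape)  # (3,)
```

A check like this fails loudly if batch_dims is dropped by accident, because the result would then have rank 2 instead of rank 1.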

Summary

  • Use axis=-1 to gather from the last dimension.
  • The indexed axis is replaced by the shape of indices.
  • Use batch_dims when each batch item has its own index values.
  • Reach for tf.gather_nd when selection spans multiple axes.
  • Validate shapes early to avoid silent indexing mistakes.
