Flatten an irregular arbitrarily nested list of lists

Introduction

Flattening an irregularly nested list means converting a structure like [1, [2, [3, 4], 5], [6, 7]] into [1, 2, 3, 4, 5, 6, 7]. The challenge is that nesting depth is unpredictable — some elements are values, some are lists, and those lists may contain more lists. A recursive generator is the cleanest Python solution. For regular (uniformly nested) lists, itertools.chain or list comprehensions work, but irregular nesting requires checking each element's type.

Method 1: Recursive Generator

python
from collections.abc import Iterable

def flatten(lst):
    for item in lst:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten(item)
        else:
            yield item

nested = [1, [2, [3, 4], 5], [6, [7, [8, 9]]]]
print(list(flatten(nested)))
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

yield from delegates to the recursive call, producing each element lazily. The str and bytes check prevents strings from being iterated character by character (since strings are iterable in Python).
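
To make the lazy behavior concrete, here is a minimal sketch that takes only the first few elements with itertools.islice. The variable name first_three is chosen for illustration; the flatten function is the same one shown above, repeated so the snippet runs on its own.

```python
from collections.abc import Iterable
from itertools import islice

def flatten(lst):
    for item in lst:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten(item)
        else:
            yield item

nested = [1, [2, [3, 4], 5], [6, [7, [8, 9]]]]

# islice pulls three values and then stops; the remaining
# sublists are never visited by the generator.
first_three = list(islice(flatten(nested), 3))
print(first_three)
# [1, 2, 3]
```

Because nothing is computed until a value is requested, the same pattern works when only a prefix of a very large nested structure is needed.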

Method 2: Recursive Function (Returns List)

python
def flatten(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result

nested = [1, [2, 3], [4, [5, [6]]]]
print(flatten(nested))
# [1, 2, 3, 4, 5, 6]

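This version is simpler, but it builds the entire result list in memory and, because it checks only for list, leaves tuples and other iterables unflattened. The generator approach is better suited to large datasets.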

Method 3: Iterative with Stack

Avoid recursion by using an explicit stack:

python
def flatten(lst):
    stack = list(reversed(lst))
    result = []
    while stack:
        item = stack.pop()
        if isinstance(item, list):
            stack.extend(reversed(item))
        else:
            result.append(item)
    return result

nested = [[1, 2], [3, [4, 5]], 6]
print(flatten(nested))
# [1, 2, 3, 4, 5, 6]

reversed() ensures elements come out in the correct order when popping from the stack. Because the function never calls itself, it is not bound by Python's recursion limit (1000 frames by default).
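
The difference shows up on very deep input. The following sketch is one possible variant, not the only way to write it: it keeps a stack of iterators so that results are still produced lazily, and compares it with a plain recursive function on a list nested 5000 levels deep. The names flatten_iter, flatten_recursive, and deep are hypothetical, and the depth of 5000 is an arbitrary value chosen to exceed the default limit.

```python
from collections.abc import Iterable

def flatten_iter(items):
    """Lazily flatten nested iterables using an explicit stack of iterators."""
    stack = [iter(items)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
                # Descend: push the child iterator and resume from it.
                stack.append(iter(item))
                break
            yield item
        else:
            # The top iterator is exhausted; return to its parent.
            stack.pop()

def flatten_recursive(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten_recursive(item))
        else:
            result.append(item)
    return result

# Build a list nested 5000 levels deep: [[[...[0]...]]]
deep = 0
for _ in range(5000):
    deep = [deep]

flat = list(flatten_iter(deep))
print(flat)
# [0]

try:
    flatten_recursive(deep)
    recursive_ok = True
except RecursionError:
    recursive_ok = False
print(recursive_ok)
# False (with the default recursion limit)
```

The parent iterator remembers its position, so after a child is exhausted and popped, iteration resumes with the parent's next element.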

Method 4: Regular Nesting with itertools.chain

For uniformly nested lists (one level deep only):

python
from itertools import chain

nested = [[1, 2], [3, 4], [5, 6]]
flat = list(chain.from_iterable(nested))
print(flat)
# [1, 2, 3, 4, 5, 6]

# Equivalent list comprehension
flat = [x for sublist in nested for x in sublist]

This does NOT work for irregular nesting — chain.from_iterable only flattens one level.
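
A short sketch of what happens on irregular input may help here. The sample lists irregular and mixed are made up for illustration.

```python
from itertools import chain

# Deeper levels survive untouched: only the outermost layer is removed.
irregular = [[1, [2, 3]], [4], [[5, 6]]]
one_level = list(chain.from_iterable(irregular))
print(one_level)
# [1, [2, 3], 4, [5, 6]]

# A bare value at the top level is not iterable, so chaining fails.
mixed = [1, [2, 3]]
try:
    list(chain.from_iterable(mixed))
    failed = False
except TypeError:
    failed = True
print(failed)
# True
```

So chain.from_iterable assumes every top-level element is itself iterable, which irregular data does not guarantee.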

Method 5: Using functools.reduce

python
from functools import reduce
import operator

# One level only
nested = [[1, 2], [3, 4], [5, 6]]
flat = reduce(operator.add, nested)
print(flat)
# [1, 2, 3, 4, 5, 6]

This is O(n^2) because each + creates a new list. Avoid for large inputs.
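
For comparison, here is a sketch of a one-level flatten that extends a single list in place, so each element is copied only once. The helper name flatten_one_level is hypothetical. The snippet also shows a side issue with reduce: without an initial value it raises TypeError on an empty input.

```python
from functools import reduce
import operator

def flatten_one_level(nested):
    """Flatten one level by growing a single result list in place."""
    flat = []
    for sublist in nested:
        flat.extend(sublist)  # appends to the same list, no full copy
    return flat

nested = [[1, 2], [3, 4], [5, 6]]
print(flatten_one_level(nested))
# [1, 2, 3, 4, 5, 6]

# Empty input: extend-based version returns [], reduce needs an initializer.
print(flatten_one_level([]))
# []
print(reduce(operator.add, [], []))
# []
try:
    reduce(operator.add, [])
    empty_failed = False
except TypeError:
    empty_failed = True
print(empty_failed)
# True
```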

Method 6: NumPy (Numeric Arrays Only)

python
import numpy as np

# Only works for regular-shaped numeric arrays
nested = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = np.array(nested).flatten().tolist()
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

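NumPy expects a regular (rectangular) shape. Depending on the NumPy version, ragged input either raises an error or produces an object array whose elements are still lists, so this approach cannot handle irregular nesting.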

Handling Mixed Types

python
from collections.abc import Iterable

def flatten(lst):
    for item in lst:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten(item)
        else:
            yield item

# Mixed types: numbers, strings, tuples, nested lists
data = [1, "hello", [2, (3, 4)], [[5], "world"]]
print(list(flatten(data)))
# [1, 'hello', 2, 3, 4, 5, 'world']

Strings remain intact because of the isinstance(item, (str, bytes)) guard. Tuples and other iterables are flattened.

Depth-Limited Flattening

Sometimes you only want to flatten N levels deep:

python
def flatten(lst, depth=1):
    for item in lst:
        if isinstance(item, list) and depth > 0:
            yield from flatten(item, depth - 1)
        else:
            yield item

nested = [1, [2, [3, [4]]]]
print(list(flatten(nested, depth=1)))
# [1, 2, [3, [4]]]

print(list(flatten(nested, depth=2)))
# [1, 2, 3, [4]]
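
Two edge cases of the depth parameter are worth seeing, sketched here with the same function. Passing float("inf") as the depth is a convenience assumed for this example, not something the function was explicitly designed for; it works because subtracting 1 from infinity leaves infinity.

```python
def flatten(lst, depth=1):
    for item in lst:
        if isinstance(item, list) and depth > 0:
            yield from flatten(item, depth - 1)
        else:
            yield item

nested = [1, [2, [3, [4]]]]

# depth=0 leaves the structure untouched.
unchanged = list(flatten(nested, depth=0))
print(unchanged)
# [1, [2, [3, [4]]]]

# An infinite depth never reaches zero, so every level is flattened.
fully_flat = list(flatten(nested, depth=float("inf")))
print(fully_flat)
# [1, 2, 3, 4]
```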

Common Pitfalls

  • Strings are iterable: Without the str/bytes guard, "hello" is iterated character by character, and because a one-character string such as "h" is itself iterable and yields itself, the recursion never bottoms out and ends in RecursionError.
  • Recursion depth: Python's default recursion limit is 1000. Structures nested close to or beyond that many levels raise RecursionError in the recursive versions. Use the iterative stack approach for untrusted input.
  • Dictionaries are iterable: isinstance({}, Iterable) is True, and iterating a dict yields its keys, so the values are silently dropped. Add dict to the exclusion check if dicts should remain intact.
  • reduce(operator.add) is O(n^2): Each concatenation copies the entire accumulated list. For 10,000 sublists, this creates roughly 10,000 intermediate lists, each longer than the last. Use chain.from_iterable or a generator instead.
  • Generators are single-use: flatten() with yield returns a generator that can only be consumed once. Wrap in list() if you need to iterate multiple times.
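
The dictionary and single-use pitfalls can be sketched together. The atomic parameter below is a hypothetical addition for illustration: it names the types that should be yielded whole instead of being iterated.

```python
from collections.abc import Iterable

def flatten(items, atomic=(str, bytes, dict)):
    """Flatten nested iterables, yielding instances of `atomic` types as-is."""
    for item in items:
        if isinstance(item, Iterable) and not isinstance(item, atomic):
            yield from flatten(item, atomic)
        else:
            yield item

data = [1, {"a": 2}, ["x", ({"b": 3}, 4)]]

# With dict treated as atomic, dictionaries stay intact.
kept = list(flatten(data))
print(kept)
# [1, {'a': 2}, 'x', {'b': 3}, 4]

# Without it, only the keys come through and the values are lost.
keys_only = list(flatten(data, atomic=(str, bytes)))
print(keys_only)
# [1, 'a', 'x', 'b', 4]

# A generator can be consumed only once.
gen = flatten([1, [2, 3]])
first_pass = list(gen)
second_pass = list(gen)
print(first_pass, second_pass)
# [1, 2, 3] []
```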

Summary

  • Use a recursive generator with yield from for irregular nesting — it handles arbitrary depth and mixed types
  • Guard against strings and bytes so that text is not split apart and single characters do not trigger runaway recursion
  • Use itertools.chain.from_iterable for simple one-level flattening
  • Use an iterative stack approach when recursion depth is a concern
  • NumPy's flatten() only works for regular-shaped numeric arrays
  • Avoid reduce(operator.add) for large lists due to quadratic time complexity
