Choose the k closest points from n given points

Introduction

Given n points in a d-dimensional space and a query point Q, finding the k closest points is a fundamental problem in computational geometry, machine learning (k-NN), recommendation systems, and spatial search. The naive approach computes all n distances and selects the k smallest: O(n log n) with a full sort, or O(n log k) with a bounded max-heap. For repeated queries, spatial data structures such as KD-trees and ball trees are built once and then answer each query in roughly O(log n) time on average, as long as the dimension d stays low.

Problem Definition

  • Input: n points P = {P1, P2, ..., Pn}, a query point Q, and an integer k
  • Output: The k points from P closest to Q
  • Distance metric: Typically Euclidean, but the same approach works with Manhattan, cosine, and other distances
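Because the selection step only needs a way to compare distances, switching metrics is a one-argument change. A minimal sketch, assuming points are plain tuples of numbers; the helper names (euclidean, manhattan, cosine_distance, haversine_km, k_closest) are illustrative, not from any library. cosine_distance assumes no all-zero vectors, and haversine_km assumes (latitude, longitude) pairs in degrees and a mean Earth radius of 6371 km.

```python
import math

def euclidean(p, q):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))

def manhattan(p, q):
    return sum(abs(a - b) for a, b in zip(p, q))

def cosine_distance(p, q):
    # 1 - cosine similarity; assumes neither vector is all zeros
    dot = sum(a * b for a, b in zip(p, q))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_q = math.sqrt(sum(b * b for b in q))
    return 1.0 - dot / (norm_p * norm_q)

def haversine_km(p, q):
    # Great-circle distance; p and q are (latitude, longitude) in degrees
    lat1, lon1 = map(math.radians, p)
    lat2, lon2 = map(math.radians, q)
    h = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * 6371.0 * math.asin(math.sqrt(h))

def k_closest(points, query, k, metric=euclidean):
    # Same selection step for every metric: stable sort, keep the first k
    return sorted(points, key=lambda p: metric(p, query))[:k]

points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
print(k_closest(points, (2, 2), 3))             # [(1, 2), (3, 3), (0, 0)]
print(k_closest(points, (2, 2), 3, manhattan))  # [(1, 2), (3, 3), (5, 1)]
```

Under Manhattan distance, (5, 1), (0, 0) and (4, 4) all tie at 4, so the third neighbor differs from the Euclidean answer: the metric genuinely changes the result.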

Method 1: Sort All Distances — O(n log n)

The simplest approach — compute all distances, sort, take the first k:

```python
import math

def k_closest_sort(points, query, k):
    # Euclidean distance from p to the query point
    def distance(p):
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, query)))

    # Sort every point by distance and keep the first k
    return sorted(points, key=distance)[:k]

# Example
points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
query = (2, 2)
k = 3

result = k_closest_sort(points, query, k)
print(result)  # [(1, 2), (3, 3), (0, 0)]
# (0, 0) and (4, 4) tie at sqrt(8); the stable sort keeps the one listed first
```

Time: O(n log n). Space: O(n). Simple but does unnecessary work when k << n.

Method 2: Max-Heap — O(n log k)

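Keep a max-heap of the k closest points seen so far, with the farthest of them at the root. For each new point, push it if the heap holds fewer than k points; otherwise, replace the root if the new point is closer: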

```python
import heapq

def k_closest_heap(points, query, k):
    # Squared Euclidean distance: same ranking as Euclidean, no sqrt needed
    def distance(p):
        return sum((a - b) ** 2 for a, b in zip(p, query))

    if k <= 0:
        return []

    # heapq is a min-heap, so store negated distances for max-heap behavior:
    # max_heap[0] is always the farthest of the k points kept so far
    max_heap = []
    for point in points:
        dist = distance(point)
        if len(max_heap) < k:
            heapq.heappush(max_heap, (-dist, point))
        elif dist < -max_heap[0][0]:
            heapq.heapreplace(max_heap, (-dist, point))

    # Heap order is not distance order; sort the k survivors nearest first
    return [point for _, point in sorted(max_heap, reverse=True)]

result = k_closest_heap(points, query, k)
print(result)  # [(1, 2), (3, 3), (0, 0)]
```

Time: O(n log k). Space: O(k). Efficient when k << n: the heap never holds more than k entries, so each push or replace costs only O(log k).
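Python's standard library already packages this bounded selection as heapq.nsmallest(k, iterable, key=...). Its documentation describes it as equivalent to sorted(iterable, key=key)[:k], so results come back nearest first with ties in input order. A short sketch using squared distances; the wrapper name k_closest_nsmallest is illustrative:

```python
import heapq

def k_closest_nsmallest(points, query, k):
    # Squared Euclidean distance ranks points the same way as Euclidean
    def sq_dist(p):
        return sum((a - b) ** 2 for a, b in zip(p, query))

    # k closest points, nearest first; ties keep input order
    return heapq.nsmallest(k, points, key=sq_dist)

points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
print(k_closest_nsmallest(points, (2, 2), 3))  # [(1, 2), (3, 3), (0, 0)]

# Any iterable works, e.g. a generator that yields points one at a time
stream = (p for p in points)
print(k_closest_nsmallest(stream, (2, 2), 3))  # [(1, 2), (3, 3), (0, 0)]
```

Accepting any iterable makes this a good fit for streams: with a generator input, CPython's implementation keeps only about k candidates in memory at a time.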

Method 3: Quickselect — O(n) Average

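Quickselect partitions the array around a pivot, as in quicksort, but continues only into the side that contains position k. When it finishes, the k smallest distances occupy arr[:k], in O(n) average time: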

```python
import random

def k_closest_quickselect(points, query, k):
    def distance(p):
        return sum((a - b) ** 2 for a, b in zip(p, query))

    def partition(arr, lo, hi):
        pivot_dist = distance(arr[hi])
        i = lo
        for j in range(lo, hi):
            if distance(arr[j]) <= pivot_dist:
                arr[i], arr[j] = arr[j], arr[i]
                i += 1
        arr[i], arr[hi] = arr[hi], arr[i]
        return i

    arr = points[:]
    lo, hi = 0, len(arr) - 1

    while lo <= hi:
        # Randomize pivot
        rand_idx = random.randint(lo, hi)
        arr[rand_idx], arr[hi] = arr[hi], arr[rand_idx]

        pivot = partition(arr, lo, hi)
        if pivot == k:
            break
        elif pivot < k:
            lo = pivot + 1
        else:
            hi = pivot - 1

    return arr[:k]

result = k_closest_quickselect(points, query, 3)
print(result)  # 3 closest points (unordered)
```

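Time: O(n) average, O(n^2) worst case (random pivots make the worst case unlikely). Space: O(n) for the working copy arr; the partitioning itself needs only O(1) extra space. The k results are not in distance order, so sort them afterwards if order matters.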
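The version above recomputes a point's distance on every comparison and returns the winners unordered. A variant sketch, still Lomuto partitioning with a random pivot, computes each squared distance once up front and then sorts only the k winners (O(k log k) extra). The function name k_closest_quickselect_sorted is illustrative:

```python
import random

def k_closest_quickselect_sorted(points, query, k):
    # Compute each squared distance once and pair it with its point
    items = [(sum((a - b) ** 2 for a, b in zip(p, query)), p) for p in points]
    lo, hi = 0, len(items) - 1
    while lo < hi:
        # Random pivot moved to the end, then Lomuto partition on distance
        r = random.randint(lo, hi)
        items[r], items[hi] = items[hi], items[r]
        pivot_dist = items[hi][0]
        i = lo
        for j in range(lo, hi):
            if items[j][0] <= pivot_dist:
                items[i], items[j] = items[j], items[i]
                i += 1
        items[i], items[hi] = items[hi], items[i]
        # Keep narrowing to the side that contains position k
        if i == k:
            break
        elif i < k:
            lo = i + 1
        else:
            hi = i - 1
    # items[:k] now holds the k smallest distances; order only those
    top = items[:k]
    top.sort(key=lambda item: item[0])
    return [p for _, p in top]

points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
print(k_closest_quickselect_sorted(points, (2, 2), 3))
```

Sorting on item[0] alone avoids comparing the points themselves, so the sketch also works for point types that do not support ordering. Because (0, 0) and (4, 4) tie, the third result may be either one depending on the random pivots.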

Method 4: KD-Tree — O(log n) Per Query

For repeated queries on the same point set, build a KD-tree once in O(n log n), then answer each query in roughly O(log n) time on average (this holds in low dimensions):

```python
from scipy.spatial import KDTree
import numpy as np

points_array = np.array([(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)])
tree = KDTree(points_array)

query = np.array([2, 2])
distances, indices = tree.query(query, k=3)

print("Closest points:", points_array[indices])
print("Distances:", distances)
```
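To see roughly what a KD-tree does behind query(), here is a minimal pure-Python sketch (not SciPy's implementation). Each level splits at the median point along one axis, cycling through the axes. The search descends into the query's side first and visits the other side only if the splitting plane is closer than the current k-th best distance. For brevity it re-sorts at every level, so building costs O(n log² n) rather than O(n log n). The tuple node layout and the names build_kdtree and kdtree_k_closest are illustrative:

```python
import heapq

def build_kdtree(points, depth=0):
    # Node layout: (point, axis, left_subtree, right_subtree); None = empty
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2  # the median point becomes this node's splitter
    return (points[mid], axis,
            build_kdtree(points[:mid], depth + 1),
            build_kdtree(points[mid + 1:], depth + 1))

def kdtree_k_closest(root, query, k):
    if k <= 0:
        return []
    # Max-heap of (-squared distance, visit order, point); root = k-th best
    heap = []
    order = 0

    def visit(node):
        nonlocal order
        if node is None:
            return
        point, axis, left, right = node
        d = sum((a - b) ** 2 for a, b in zip(point, query))
        if len(heap) < k:
            heapq.heappush(heap, (-d, order, point))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, order, point))
        order += 1
        diff = query[axis] - point[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        visit(near)
        # Prune: the far side can only help if the splitting plane is
        # closer than the current k-th best (squared) distance
        if len(heap) < k or diff * diff < -heap[0][0]:
            visit(far)

    visit(root)
    # Sort survivors by (distance, visit order), nearest first
    return [p for _, _, p in sorted(heap, key=lambda t: (-t[0], t[1]))]

points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
tree = build_kdtree(points)
print(kdtree_k_closest(tree, (2, 2), 3))  # [(1, 2), (3, 3), (0, 0)]
```

The pruning test is what makes the average query fast in low dimensions: whole subtrees are skipped whenever the query's k-th best distance is already smaller than the distance to their splitting plane.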

scikit-learn Implementation

```python
from sklearn.neighbors import NearestNeighbors
import numpy as np

points = np.array([(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)])

# Exact k-NN backed by a KD-tree index
nn = NearestNeighbors(n_neighbors=3, algorithm='kd_tree')
nn.fit(points)

# kneighbors expects a 2-D array: one row per query point
query = np.array([[2, 2]])
distances, indices = nn.kneighbors(query)

print("Indices:", indices[0])       # [0 1 3]; (0, 0) and (4, 4) tie, so [0 1 4] is also valid
print("Distances:", distances[0])   # approximately [1.0, 1.414, 2.828]
```

Comparison of Methods

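| Method | Build Time | Query Time | Best When |
|---|---|---|---|
| Sort | none | O(n log n) | Single query, small n |
| Max-Heap | none | O(n log k) | Single query, k << n |
| Quickselect | none | O(n) avg | Single query, unordered result OK |
| KD-Tree | O(n log n) | O(log n) avg, low d | Many queries on the same data |
| Ball Tree | O(n log n) | O(log n) avg, data-dependent | Higher dimensions where KD-trees slow down |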

Handling High Dimensions

KD-trees degrade toward O(n) per query as dimensionality grows; a common rule of thumb puts the turning point around d > 20. Alternatives:

```python
from sklearn.neighbors import NearestNeighbors

k = 3  # number of neighbors to return

# Ball tree: usually holds up better than a KD-tree as dimensionality grows
nn = NearestNeighbors(n_neighbors=k, algorithm='ball_tree', metric='euclidean')

# Brute force: exact (like the tree methods), O(n·d) per query, no index to build
nn = NearestNeighbors(n_neighbors=k, algorithm='brute')

# Auto: scikit-learn picks an algorithm based on the data passed to fit()
nn = NearestNeighbors(n_neighbors=k, algorithm='auto')
```
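One intuition for why tree indexes lose their edge as d grows: for uniformly random data (an assumption; real data often has much lower intrinsic dimension), the farthest and nearest points end up at nearly the same distance from a query, so the pruning test rarely rules out a subtree. A quick standard-library experiment; distance_contrast is an illustrative helper, and math.dist requires Python 3.8+:

```python
import math
import random

def distance_contrast(n_points, dim, seed=0):
    # Ratio of farthest to nearest distance from one random query point
    rng = random.Random(seed)
    query = [rng.random() for _ in range(dim)]
    dists = [math.dist([rng.random() for _ in range(dim)], query)
             for _ in range(n_points)]
    return max(dists) / min(dists)

for dim in (2, 10, 100, 1000):
    # The ratio shrinks toward 1 as dim grows: "near" and "far" blur together
    print(dim, round(distance_contrast(1000, dim), 2))
```

When the ratio approaches 1, almost every point is about as close as the best candidate found so far, so a tree search ends up inspecting most of the data, which is the brute-force behavior described above.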

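For very large datasets (millions of points), specialized libraries such as FAISS help. The IndexFlatL2 index below still performs exact brute-force search, just heavily optimized; for approximate nearest neighbors, FAISS also provides indexes such as IndexIVFFlat and IndexHNSWFlat that trade some recall for much faster queries: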

```python
# Using FAISS (Facebook AI Similarity Search)
import faiss
import numpy as np

points = np.random.rand(1_000_000, 128).astype('float32')  # about 512 MB
query = np.random.rand(1, 128).astype('float32')

index = faiss.IndexFlatL2(128)  # Exact search
index.add(points)

# Returns squared L2 distances and row indices, each of shape (1, 10)
distances, indices = index.search(query, k=10)
```
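To illustrate what "approximate" means without extra dependencies, here is a toy locality-sensitive hashing (LSH) index using the p-stable scheme for Euclidean distance. Each hash is floor((a·v + b) / w) with a drawn from a standard Gaussian and b uniform in [0, w), so nearby points tend to land in the same bucket. A query re-ranks only the points that share a bucket with it, which is fast but can miss true neighbors or return fewer than k. The class name EuclideanLSH and the defaults for num_tables, num_hashes, and bucket_width are hypothetical and would need tuning on real data:

```python
import math
import random
from collections import defaultdict

class EuclideanLSH:
    # Hypothetical toy index: p-stable LSH for approximate Euclidean k-NN

    def __init__(self, dim, num_tables=8, num_hashes=4, bucket_width=1.0, seed=0):
        rng = random.Random(seed)
        self.width = bucket_width
        self.points = []
        # Each table combines num_hashes projections; a bucket key is the
        # tuple of their floor((a·v + b) / w) values
        self.tables = []
        for _ in range(num_tables):
            funcs = [([rng.gauss(0.0, 1.0) for _ in range(dim)],
                      rng.uniform(0.0, bucket_width))
                     for _ in range(num_hashes)]
            self.tables.append((funcs, defaultdict(list)))

    def _key(self, funcs, v):
        return tuple(
            math.floor((sum(ai * vi for ai, vi in zip(a, v)) + b) / self.width)
            for a, b in funcs)

    def add(self, point):
        idx = len(self.points)
        self.points.append(point)
        for funcs, buckets in self.tables:
            buckets[self._key(funcs, point)].append(idx)

    def query(self, q, k):
        # Candidates: every point sharing a bucket with q in any table
        candidates = set()
        for funcs, buckets in self.tables:
            candidates.update(buckets.get(self._key(funcs, q), ()))
        # Exact re-ranking of the candidates only; may return fewer than k
        ranked = sorted(candidates,
                        key=lambda i: (math.dist(self.points[i], q), i))
        return [self.points[i] for i in ranked[:k]]

rng = random.Random(1)
data = [tuple(rng.random() for _ in range(8)) for _ in range(2000)]
index = EuclideanLSH(dim=8)
for p in data:
    index.add(p)
query = tuple(rng.random() for _ in range(8))
print(index.query(query, 5))  # up to 5 approximate neighbors
```

More tables raise recall (more chances to collide with a true neighbor); more hashes per table make buckets smaller and queries cheaper. Production libraries tune these trade-offs far more carefully than this sketch.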

Common Pitfalls

  • Comparing squared vs actual distances: When only ranking points (not needing actual distances), skip the sqrt operation — squared distances preserve the ordering and are faster to compute.
  • KD-tree in high dimensions: KD-trees work well in low dimensions (roughly d < 20) but slow down toward brute-force speed as d grows. Use ball trees, approximate methods (FAISS, Annoy), or locality-sensitive hashing instead.
  • Mutable input with quickselect: Quickselect rearranges the input array in-place. Copy the array first (points[:]) if you need to preserve the original order.
  • Tie-breaking: When multiple points have the same distance to the query, the choice of which k to return is arbitrary. Be consistent (e.g., use index as a tiebreaker) if determinism matters.
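  • Distance metric mismatch: Euclidean distance works well for numeric features on similar scales; standardize features first so that one large-valued feature does not dominate. For text represented as TF-IDF or embedding vectors, cosine similarity is usually the better fit. For geographic (latitude, longitude) coordinates, use the Haversine (great-circle) distance. The wrong metric returns neighbors that are numerically close but not meaningfully similar.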
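The squared-distance and tie-breaking advice above fits in a few lines. A minimal sketch, assuming points are tuples and that "first in the input wins" is an acceptable tie rule; the function name k_closest_deterministic is illustrative:

```python
def k_closest_deterministic(points, query, k):
    # Rank by (squared distance, original index): no sqrt is needed, and
    # ties always go to the point that appears first in the input
    def sq_dist(p):
        return sum((a - b) ** 2 for a, b in zip(p, query))

    order = sorted(range(len(points)), key=lambda i: (sq_dist(points[i]), i))
    return [points[i] for i in order[:k]]

points = [(1, 2), (3, 3), (5, 1), (0, 0), (4, 4)]
print(k_closest_deterministic(points, (2, 2), 3))        # [(1, 2), (3, 3), (0, 0)]
print(k_closest_deterministic(points[::-1], (2, 2), 3))  # [(1, 2), (3, 3), (4, 4)]
```

Reversing the input flips which of the tied points (0, 0) and (4, 4) is returned, but for a given input order the answer is always the same.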

Summary

  • For a single query: use a max-heap for O(n log k) or quickselect for O(n) average
  • For repeated queries on the same data: build a KD-tree once for roughly O(log n) average query time in low dimensions
  • Use scipy.spatial.KDTree or sklearn.neighbors.NearestNeighbors for production implementations
  • Skip sqrt when only comparing distances — squared Euclidean is sufficient for ranking
  • For high dimensions (d > 20) or millions of points, use FAISS or approximate nearest neighbor libraries

