Vision Transformers

Topics Covered

From CNNs to Patches

Why We Needed Something Else

The Patch Tokenization Trick

The Patch Projection in Code

Patch Size as the Key Hyperparameter

Why This Was Not Obvious

The Receptive Field Argument, Reconsidered

Patch Tokens vs Conv Features: A Concrete Comparison

A Note on Computational Cost

ViT Architecture

The Full Forward Pass

The CLS Token Trick

Transformer Blocks: Pre-LN Edition

Scaling the Model: ViT-S, B, L, H

The MLP Head

Putting it Together

What the Attention Heads Actually Do

Initialization Matters More Than You Think

Layer Scale and Training Stability for Deep ViTs

The Role of the MLP Block

Position Embeddings for Images

The Original ViT Choice: Learned 1D Positions

Why 1D Works (Mostly)

The Variable Resolution Problem

2D Sinusoidal Embeddings

Relative Position Bias (Swin and Friends)

RoPE-2D and Rotary Position Encoding

Which Position Encoding Should You Use?

A Worked Example: Position Embedding Interpolation

The Special Case of Video

Position Encoding and Generalization

Position Embeddings and Attention Interpretability

When ViT Beats CNNs

The Data-Efficiency Crossover

DeiT: Making ViT Data-Efficient

Swin Transformer: Hierarchical Attention for Images

ConvNeXt: Modernizing the CNN

MAE-Pretrained ViT: Self-Supervised Scaling

Practical Guidance

The Bigger Lesson

A Quick Tour of ViT Descendants

ViT in Production: Practical Notes

Measuring ViT vs CNN in Practice

For almost a decade, the answer to "how do I process an image with a neural network?" was "use a CNN." Convolutions gave you translation equivariance, locality, and weight sharing for free. That inductive bias was so strong that it defined vision research from AlexNet (2012) through EfficientNet (2019). Then in 2020, a Google Brain paper called "An Image is Worth 16x16 Words" proposed something that looked almost like heresy: throw away the convolutions entirely, chop the image into flat patches, and feed them to a standard transformer encoder. If you had enough data, they claimed, you could match or beat a ResNet.

The claim landed harder than anyone expected. The Vision Transformer (ViT) is now the default backbone for most modern vision systems, from CLIP to DINOv2 to the vision towers inside most multimodal LLMs. And the core trick is so simple that you can implement it in ten lines of PyTorch. The hard part is understanding why it works, and under what conditions.

Why We Needed Something Else

CNNs were not failing. They were winning. But they had three quiet problems that a few research groups kept poking at.

The first problem was that convolutions only see local windows. A $3 \times 3$ filter sees nine pixels. To see all $224 \times 224$ pixels you need a deep stack of layers that progressively widens the receptive field through downsampling and dilation. This works, but it means information from opposite corners of the image only meets at the top of the network. Early layers never see the big picture. Self-attention, by contrast, is global from the very first layer.

The second problem was architectural. CNNs for vision and transformers for language were two completely separate families. Every multimodal system had to bolt them together with projection layers and hope the representations were compatible. If you could process images with a transformer too, the whole ML stack could standardize on one architecture. That is, in retrospect, the single biggest engineering win ViT delivered.

The third problem was scaling. CNNs scale, but not as cleanly as transformers do. Transformers have a known recipe: more data, more parameters, more compute, better results, predictable curves. CNNs had diminishing returns and lots of task-specific tuning. When you are trying to pretrain a foundation model on a billion images, you want the architecture that scales most predictably.

Key Insight

The real contribution of ViT was not 'transformers beat CNNs.' It was 'image modeling can be unified with language modeling.' Once you accept that an image is just a sequence of patch tokens, every technique from the LLM world (BERT-style masking, contrastive pretraining, instruction tuning, mixture-of-experts) transfers to vision with almost no change. That architectural unification is what made today's large multimodal models practical to build.

The Patch Tokenization Trick

Here is the entire conceptual leap. An image is an $H \times W \times C$ array of pixel values. A transformer expects a sequence of token embeddings, each a $D$-dimensional vector. How do you bridge the two?

ViT's answer: divide the image into non-overlapping patches of size $P \times P$ (typically 16x16). A 224x224 image gives you $(224 / 16)^2 = 196$ patches. Flatten each patch into a vector of length $P \cdot P \cdot C = 16 \cdot 16 \cdot 3 = 768$. Project that vector through a single linear layer to get a $D$-dimensional embedding. Those 196 embeddings are your token sequence. Add positional encodings. Feed to a transformer.

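That is it. That is the whole trick. The image becomes a sequence of 196 tokens, each $D$-dimensional ($D = 768$ for ViT-Base, which happens to equal the flattened patch length), and from that point on, the rest of the architecture is essentially a standard BERT-style transformer encoder.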

[Figure: Image → patches → linear projection → token sequence. A 224×224 image is split into 16×16 patches (14×14 = 196 in a real ViT), each flattened to P·P·C = 16·16·3 = 768 values, projected by Linear(768 → D), then combined with position embeddings and fed to the transformer.]
Divide the image into non-overlapping P×P patches, flatten each into a P²·C-dim vector, then project to D dims through a single linear layer. The result is a sequence of D-dim tokens, exactly what a transformer expects.

Why does this work? Because self-attention does not care about the spatial arrangement of its inputs. It treats the sequence as a set (with positional encodings adding back just enough order information). The "2D-ness" of an image is handled entirely by the positional embeddings, not by the attention mechanism itself. This is the freedom that lets you unify image and text processing.
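To see the "treats the sequence as a set" point concretely, here is a minimal sketch that uses PyTorch's nn.MultiheadAttention as a stand-in for one ViT attention layer (random weights, random tokens, no position embeddings; the sizes are just ViT-B/16-like choices). Shuffling the input tokens only shuffles the output tokens in the same way, so nothing inside the layer knows where a patch came from.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, D = 196, 768                 # ViT-B/16-like sequence length and width
tokens = torch.randn(1, N, D)   # patch tokens, no position embeddings added

# One self-attention layer with 12 heads; batch_first means inputs are (B, N, D)
attn = nn.MultiheadAttention(embed_dim=D, num_heads=12, batch_first=True)
attn.eval()

perm = torch.randperm(N)        # shuffle the patch order
shuffled = tokens[:, perm]

with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    out_perm, _ = attn(shuffled, shuffled, shuffled)

# Permuting the inputs permutes the outputs identically (permutation equivariance)
print(torch.allclose(out[:, perm], out_perm, atol=1e-4))  # True
```

Once a position embedding indexed by grid location is added to each token, two different arrangements of the same patches become genuinely different inputs, which is exactly the job the positional encodings do.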

The Patch Projection in Code

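$x_p^i = E \cdot \text{flatten}(\text{patch}_i) + p_i$, where $E$ is the shared linear projection (a $D \times P^2 C$ matrix) and $p_i$ is the position embedding for patch $i$.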
```python
import torch
import torch.nn as nn

B, C, H, W = 1, 3, 224, 224
P = 16   # patch size
D = 768  # embedding dimension
img = torch.randn(B, C, H, W)

# Reshape into patches: (B, C, H, W) -> (B, n_patches, C*P*P)
# unfold(dim, size, step) with size == step cuts non-overlapping windows
patches = img.unfold(2, P, P).unfold(3, P, P)        # (B, C, 14, 14, P, P)
patches = patches.contiguous().view(B, C, -1, P, P)  # (B, C, 196, P, P)
patches = patches.permute(0, 2, 1, 3, 4).flatten(2)  # (B, 196, C*P*P)
print(patches.shape)  # torch.Size([1, 196, 768])

proj = nn.Linear(C * P * P, D)  # the projection matrix E
tokens = proj(patches)
print(tokens.shape)  # torch.Size([1, 196, 768])
```
ViT turns an image into a sequence of patch embeddings, then runs standard transformer blocks on it. Nothing about the transformer itself is vision-specific.
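The snippet above stops at $E \cdot \text{flatten}(\text{patch}_i)$; adding the $p_i$ term gives the full equation. Here is a minimal sketch as a module, assuming learned 1D position embeddings (one vector per patch slot, as in the original ViT). The class name PatchEmbedding and the 0.02 initialization scale are illustrative choices, not something this lesson prescribes.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical module computing x_p^i = E * flatten(patch_i) + p_i."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        self.P = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Linear(in_chans * patch_size * patch_size, dim)  # E
        # One learned vector per patch position (p_i); the 0.02 std is an assumption
        self.pos = nn.Parameter(torch.randn(1, self.n_patches, dim) * 0.02)

    def forward(self, img):
        B, C, H, W = img.shape
        P = self.P
        # (B, C, H, W) -> (B, C, H/P, W/P, P, P) -> (B, n_patches, C*P*P)
        patches = img.unfold(2, P, P).unfold(3, P, P)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)
        return self.proj(patches) + self.pos  # pos broadcasts over the batch


embed = PatchEmbedding()
x = embed(torch.randn(2, 3, 224, 224))
print(x.shape)  # torch.Size([2, 196, 768])
```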

An equivalent and more efficient implementation uses a single convolution with kernel size and stride both equal to $P$. This does the patching and the projection in one operation: nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16). The convolution here is not doing convolutional feature detection. It is just an efficient implementation of "chop into patches and linearly project each one." The weight-sharing property of the conv layer aligns perfectly with applying the same projection matrix to every patch.

This is often a point of confusion: ViT does use a convolutional layer at the very start. But it is a non-overlapping conv with stride equal to kernel size, which is mathematically identical to the flatten-and-project formulation. No spatial information leaks between patches during this step. After this layer, there are no more convolutions anywhere in the model.
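The equivalence is easy to verify numerically. A minimal sketch, assuming PyTorch's default weight layouts (Linear weight of shape (D, C·P·P), Conv2d weight of shape (D, C, P, P)): copy the linear projection's weights into a conv with kernel_size = stride = P and compare the two outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
img = torch.randn(B, C, H, W)

# Path 1: flatten-and-project (row-major patch order, (C, P, P) flatten order)
patches = img.unfold(2, P, P).unfold(3, P, P)                   # (B, C, 14, 14, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)
linear = nn.Linear(C * P * P, D)
tokens_linear = linear(patches)                                 # (B, 196, D)

# Path 2: non-overlapping conv with kernel_size == stride == P
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
with torch.no_grad():
    # Same numbers, just viewed as D kernels of shape (C, P, P)
    conv.weight.copy_(linear.weight.view(D, C, P, P))
    conv.bias.copy_(linear.bias)
    tokens_conv = conv(img).flatten(2).transpose(1, 2)          # (B, D, 14, 14) -> (B, 196, D)

print(torch.allclose(tokens_linear, tokens_conv, atol=1e-4))    # True
```

The weight copy works because flattening a (C, P, P) patch and flattening a (C, P, P) kernel use the same element order, so each conv output position computes exactly the dot product the Linear layer computes for that patch.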

Patch Size as the Key Hyperparameter

Patch size $P$ is the one hyperparameter that changes everything. Smaller patches mean more tokens, which means more compute (attention is $O(N^2)$ in sequence length $N$) but finer-grained representations.

  • ViT-B/16: patches of 16x16, 196 tokens for a 224x224 image
  • ViT-B/32: patches of 32x32, 49 tokens for the same image (4x fewer tokens, roughly 4x less compute, less detail)
  • ViT-B/8: patches of 8x8, 784 tokens (4x more tokens, at least 4x more compute since the attention term grows 16x, much finer granularity)
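The arithmetic behind these bullets is easy to check. In this minimal sketch (the helper name patch_stats is made up for illustration), per-token work such as the projections and MLP grows with the token count N, while the attention score matrix has N² entries, so halving the patch size quadruples the tokens and multiplies the attention entries by 16.

```python
def patch_stats(img_size: int, patch_size: int) -> tuple[int, int]:
    """Return (number of tokens, number of pairwise attention entries) for a square image."""
    n = (img_size // patch_size) ** 2
    return n, n * n


base_tokens, base_pairs = patch_stats(224, 16)
for p in (32, 16, 8):
    n, pairs = patch_stats(224, p)
    print(f"P={p:2d}: {n:4d} tokens, {n / base_tokens:5.2f}x tokens, "
          f"{pairs / base_pairs:6.2f}x attention entries vs P=16")
```

For P = 8 this reports 784 tokens, 4x the tokens and 16x the attention entries of P = 16; for P = 32 it reports 49 tokens, a quarter of the tokens and a sixteenth of the attention entries.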

The 16x16 default is a compromise. It keeps the sequence length manageable while still giving enough resolution for standard classification tasks. For dense prediction tasks (segmentation, detection), smaller patches or hierarchical architectures (like Swin) are usually needed because these tasks need spatially fine-grained output: per-pixel masks for segmentation, precise box locations for detection.

The choice of patch size also changes what a "token" means. With 16x16 patches, each token represents roughly a small object part (an eye, a wheel, a letter on a sign). With 32x32 patches, a token represents something like half a face or a whole license plate. The model's representations are built on top of these patch units, so the granularity of the patch determines the granularity of what the model can reason about directly.

Interview Tip

If your task needs fine spatial detail, do not immediately reach for smaller patches. Instead, consider hierarchical architectures like Swin Transformer that start with small patches and progressively merge them as you go deeper. This gives you the fine detail of small patches at early layers and the computational efficiency of large patches at later layers.
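Here is a minimal sketch of the merging step, modeled on the 2x2 patch-merging layer described for Swin Transformer: concatenate each 2x2 block of neighboring tokens, normalize, and linearly project the 4C channels down to 2C. It is a simplified illustration rather than Swin's reference code; the 56x56x96 starting grid mirrors Swin-T's 4x4 patches on a 224x224 image.

```python
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """Sketch of Swin-style 2x2 patch merging: halves the grid, doubles the channels."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduce = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, W, C) grid of patch tokens; H and W are assumed even
        x0 = x[:, 0::2, 0::2]  # top-left token of each 2x2 block
        x1 = x[:, 1::2, 0::2]  # bottom-left
        x2 = x[:, 0::2, 1::2]  # top-right
        x3 = x[:, 1::2, 1::2]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
        return self.reduce(self.norm(x))         # (B, H/2, W/2, 2C)


# Swin-T-like start: 4x4 patches on 224x224 give a 56x56 grid of 96-dim tokens
x = torch.randn(1, 56, 56, 96)
for stage in range(3):
    x = PatchMerging(x.shape[-1])(x)
    print(x.shape)
```

Each merge halves the grid and doubles the width: 56×56×96 becomes 28×28×192, then 14×14×384, then 7×7×768, the CNN-like pyramid that a single-resolution ViT does not have.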

Why This Was Not Obvious

In hindsight, patch tokenization looks trivial. But from a 2019 perspective, it was genuinely surprising that it worked at all. Everyone expected that vision models needed explicit spatial inductive biases (locality, translation equivariance) to be efficient. Throwing those away felt like asking the model to re-learn basic facts about images from scratch using billions of examples.

And in fact, that is what ViT does. It rediscovers locality, approximate translation equivariance, and hierarchical feature extraction purely from data. Analyses of trained ViTs find patch-embedding filters that resemble the edge- and color-sensitive filters of a CNN's first layer, many early-layer attention heads that attend mostly to nearby patches, attention patterns that follow object boundaries in middle layers, and increasingly abstract representations in later layers, the same rough hierarchy you see in CNNs. The model is not avoiding these concepts; it is learning them instead of having them hardcoded.

This is why ViT needs more data than a CNN to reach the same accuracy on ImageNet. A CNN starts with useful inductive biases baked in. ViT starts from scratch and has to learn them. When the dataset is small, the CNN wins. When the dataset is large enough for the ViT to learn the right biases, the ViT wins, and by larger and larger margins as you keep scaling. We will come back to this crossover point later in the lesson.

The Receptive Field Argument, Reconsidered

One way to think about why global attention matters is to compare receptive fields. A ResNet-50 has a theoretical receptive field covering the entire input image at its deepest layer, but the effective receptive field (measured by gradient-based attribution) is much smaller, typically around a quarter of the input image width. Most of the "action" in a CNN is local, because most conv filters look only at a 3x3 (or 1x1) window of the previous layer, and information has to travel up the stack layer by layer to propagate globally.
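The "layer by layer" point can be made precise with the standard receptive-field recursion: each layer with kernel size k widens the window by (k - 1) times the current spacing between neighboring outputs (measured in input pixels), and each stride multiplies that spacing. This minimal sketch computes the theoretical receptive field only, not the smaller effective one, and the layer lists are illustrative rather than ResNet-50's exact layout.

```python
def receptive_field(layers):
    """Theoretical receptive field, in input pixels, of a stack of conv/pool layers.

    layers: list of (kernel_size, stride) tuples, applied in order.
    """
    r, jump = 1, 1  # start from one input pixel; jump = spacing of adjacent outputs
    for k, s in layers:
        r += (k - 1) * jump
        jump *= s
    return r


# Stacked 3x3 stride-1 convs widen the window by only 2 pixels per layer
print([receptive_field([(3, 1)] * n) for n in (1, 2, 5, 10)])  # [3, 5, 11, 21]

# Downsampling early makes later layers grow faster: a ResNet-like stem
# (7x7 stride-2 conv, 3x3 stride-2 max-pool) followed by four 3x3 convs
stem = [(7, 2), (3, 2)]
print(receptive_field(stem + [(3, 1)] * 4))  # 43
```

Without downsampling, ten 3x3 layers still only reach 21 pixels; strides are what let a CNN eventually cover the whole image, at the cost of spatial resolution.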

ViT does not have this issue. From the very first transformer block, every token attends to every other token. A patch in the top-left of the image can directly contribute to the output of a patch in the bottom-right. This is not just a theoretical difference. Probing studies on trained ViTs show that some attention heads in the first layer already compute globally distributed patterns, picking up information from across the entire image before any deeper processing happens.
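One common way to quantify this is the mean attention distance used in the original ViT paper's analysis: for each head, weight the pixel distance between a query patch and every key patch by the attention paid to it, then average over queries. Here is a minimal sketch on synthetic attention maps; the function name is hypothetical, and on a real model you would pass in each head's attention weights over the 196 patch tokens with the CLS token excluded.

```python
import torch


def mean_attention_distance(attn: torch.Tensor, grid: int, patch_size: int) -> torch.Tensor:
    """Attention-weighted pixel distance between query and key patches, one value per head.

    attn: (heads, N, N) attention weights over patch tokens only (rows sum to 1).
    """
    # (row, col) of every patch on the grid, in pixel units
    ys, xs = torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float() * patch_size  # (N, 2)
    dist = (coords[:, None, :] - coords[None, :, :]).norm(dim=-1)                    # (N, N)
    # Weight each distance by the attention paid, sum over keys, average over queries
    return (attn * dist).sum(-1).mean(-1)


grid, P = 14, 16
N = grid * grid
local = torch.eye(N).unsqueeze(0)          # every patch attends only to itself
uniform = torch.full((1, N, N), 1.0 / N)   # every patch attends everywhere equally
print(mean_attention_distance(local, grid, P))    # tensor([0.])
print(mean_attention_distance(uniform, grid, P))  # large: attention spread across the image
```

A purely local head scores zero, while a head that spreads attention evenly over a 224x224 input scores on the order of 100 pixels; real first-layer heads fall at various points in between.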

For tasks where global reasoning matters (counting objects, reasoning about scene layout, comparing distant regions), this head start can be significant. For tasks that are mostly local (classifying a centered object, detecting a texture, segmenting at pixel granularity), the difference is smaller. The practical value of global attention depends on the task, and the architectural advantage of ViT is strongest for tasks that benefit from it.

Patch Tokens vs Conv Features: A Concrete Comparison

It helps to see the patch-token view side by side with the conv-feature view to appreciate how different they are as representations.

In a ResNet, the first layer applies 64 filters of size 7x7 with stride 2 to the input image, producing a 112x112x64 feature map. Each spatial position in that feature map is the output of one filter applied to a 7x7 window of the input, which overlaps with neighboring windows. The information is dense and redundant: each output value depends on pixels that also contribute to the neighbors.

In a ViT, the first layer applies a 16x16 conv with stride 16 (or equivalently, flattens and projects each 16x16 patch), producing a 14x14 grid of 768-dimensional tokens, i.e. a sequence of 196 tokens. There is no overlap between patches. Each token represents a completely disjoint region of the input, with a much richer 768-dimensional representation than the 64 channels the ResNet starts with.

The tradeoff is illuminating. The ResNet has many low-dimensional, overlapping feature points; the ViT has few high-dimensional, non-overlapping patch tokens. Neither representation is strictly better. The ResNet captures fine-grained spatial variation but spends computation on redundant, overlapping windows. The ViT captures less spatial detail but represents each patch with enough dimensions to encode complex patterns directly.
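The shape difference is easy to see by running both first layers on the same input. A minimal sketch with randomly initialized layers: a ResNet-style 7x7 stride-2 stem conv (padding 3 and no bias, as in common ResNet implementations) next to a ViT-B/16 patch-embedding conv.

```python
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)

# ResNet-style stem: 64 overlapping 7x7 filters, stride 2 (padding 3 maps 224 -> 112)
resnet_stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# ViT-B/16 patch embedding: 768 non-overlapping 16x16 filters, stride 16
vit_patch = nn.Conv2d(3, 768, kernel_size=16, stride=16)

r = resnet_stem(img)
v = vit_patch(img)
print(r.shape)  # torch.Size([1, 64, 112, 112])
print(v.shape)  # torch.Size([1, 768, 14, 14])

# Spatial positions per channel: many overlapping points vs few disjoint patches
print(r[0, 0].numel(), v[0, 0].numel())  # 12544 196
```

The ViT output holds 196 × 768 = 150,528 numbers, exactly the number of input values (224 × 224 × 3), regrouped into patches and linearly mixed; the ResNet stem emits 64 × 12,544 = 802,816 values from overlapping 7x7 windows.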

This difference ripples through the entire architecture. A ResNet processes its 112x112 grid through multiple downsampling stages (112 → 56 → 28 → 14 → 7), building up from fine to coarse features. A ViT processes its 14x14 grid at the same resolution throughout, with the representation getting more abstract in the channel dimension but staying at the same spatial resolution. The ResNet is spatially hierarchical; the ViT is not. This is why ViT is awkward for dense prediction tasks that need high-resolution output, and why hierarchical ViT variants like Swin were introduced.

A Note on Computational Cost

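It is worth getting the computational cost right because the intuition can mislead. A transformer layer has two kinds of cost: the attention matrix products ($QK^\top$ and the attention-weighted sum over $V$) are $O(N^2 D)$ in sequence length $N$ and hidden dim $D$, while the Q/K/V and output projections and the MLP are $O(N D^2)$. For a ViT-B/16 with 196 + 1 = 197 tokens, $D = 768$, and 12 heads of dimension 64, the two attention matrix products cost about $2 \cdot 197^2 \cdot 768 \approx 60$ million multiply-adds per layer (summed over heads), or roughly 0.7 billion MACs across all 12 layers. The linear layers dominate: the projections ($4 N D^2 \approx 465$ million MACs) and the 4x-wide MLP ($8 N D^2 \approx 930$ million MACs) bring each layer to about 1.45 billion MACs, and the 12-layer encoder to roughly 17.4 billion MACs. That is about four times a ResNet-50 on a 224x224 input (4.1 billion MACs): more expensive, but the same order of magnitude, and at this sequence length the quadratic attention term is only a few percent of the total because $N$ is smaller than $D$. The "ViT is expensive" reputation comes mostly from larger variants (ViT-L, ViT-H) and from running ViT at higher resolution or with smaller patches, where $N$ grows and the quadratic cost in patch count does become a problem.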
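A back-of-the-envelope sketch of these counts. It tallies multiply-accumulates for the main matrix products of the encoder only, ignoring LayerNorm, softmax, GELU, the patch embedding, and the classification head, and it assumes the standard 4x MLP expansion; the function name vit_macs is hypothetical.

```python
def vit_macs(n_tokens: int, dim: int, depth: int, mlp_ratio: int = 4) -> dict:
    """Rough MAC counts for a ViT encoder's matrix products (norms/softmax/GELU ignored)."""
    attn_proj = 4 * n_tokens * dim * dim            # Q, K, V and output projections
    attn_scores = 2 * n_tokens * n_tokens * dim     # QK^T and attn @ V, summed over heads
    mlp = 2 * n_tokens * dim * (mlp_ratio * dim)    # the two MLP linear layers
    return {
        "attention_scores": depth * attn_scores,
        "linear_layers": depth * (attn_proj + mlp),
        "total": depth * (attn_proj + attn_scores + mlp),
    }


# ViT-B/16 at 224x224: 196 patches + 1 CLS token, D = 768, 12 layers
for name, macs in vit_macs(197, 768, 12).items():
    print(f"{name:>16}: {macs / 1e9:.2f} GMACs")
```

For ViT-B/16 this gives about 0.72 GMACs for the attention matrix products, 16.73 GMACs for the linear layers, and 17.45 GMACs in total. At 448x448 input the sequence grows to 785 tokens, and the attention-score share of the total rises from about 4% to about 15%.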