How can I load a partial pretrained PyTorch model?

Introduction

Loading part of a pretrained PyTorch model is a common transfer-learning task when your new architecture overlaps with an older checkpoint but is not identical. The safest approach is to work with state_dict objects, copy only matching weights, and be explicit about which layers should be skipped.

Start With state_dict

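In PyTorch, partial loading is easiest when you save and load state_dict dictionaries rather than entire model objects. A state_dict is an ordered dictionary that maps each parameter name (and each persistent buffer name, such as batch-norm running statistics) to its tensor.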

python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(20, 2)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


model = SmallNet()
# Save only the tensors, not the pickled model object.
torch.save(model.state_dict(), "model.pt")

That format is portable and much easier to adapt when only some layers match.
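
To make the name-to-tensor mapping concrete, here is a minimal sketch that lists the keys and shapes of a state_dict. It redefines the same SmallNet layout so the snippet runs on its own; the `shapes` dictionary is only an illustrative helper, not part of PyTorch.

```python
import torch.nn as nn


class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
        self.classifier = nn.Linear(20, 2)

    def forward(self, x):
        return self.classifier(self.features(x))


state = SmallNet().state_dict()

# Each key is a dotted path to a tensor. ReLU has no parameters, so it adds no keys.
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}
for name, shape in shapes.items():
    print(name, shape)
```

These dotted names are what load_state_dict matches on, so printing them for both the checkpoint and the new model is usually the quickest way to see which layers can be reused.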

Use strict=False When Some Layers Differ

If the new model still contains many of the old parameter names, you can often load the checkpoint directly with strict=False.

python
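new_model = SmallNet()
checkpoint = torch.load("model.pt", map_location="cpu")

# load_state_dict returns the keys it could not match in either direction.
missing, unexpected = new_model.load_state_dict(checkpoint, strict=False)

# Both lists are empty here because the two architectures are identical.
print("missing:", missing)
print("unexpected:", unexpected)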

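strict=False tells PyTorch not to fail when the checkpoint and the model have different sets of keys. Parameters that are missing from the checkpoint keep their current initialization, and checkpoint entries that the model does not have are ignored. It does not relax shape checks: a key that exists in both but with a different tensor shape still raises a RuntimeError. That makes strict=False useful when you added or removed layers; a replaced classifier head that keeps its old name but changes shape needs the filtering described in the next section.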
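
As a quick check of what strict=False does and does not tolerate, the sketch below loads one checkpoint into two variants of a toy model. The nn.Sequential models and the names `extended` and `resized` are assumptions made for illustration only.

```python
import torch.nn as nn

old = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
checkpoint = old.state_dict()

# Case 1: the new model has an extra layer that the checkpoint knows nothing about.
extended = nn.Sequential(
    nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2), nn.Linear(2, 2)
)
result = extended.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, not in the model

# Case 2: identical key names, but the last layer has a different output size.
resized = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
shape_error = None
try:
    resized.load_state_dict(checkpoint, strict=False)
except RuntimeError as err:
    shape_error = err
print("size mismatch raised:", shape_error is not None)
```

The first case loads and reports the extra layer as missing; the second raises even with strict=False, because the key names match and only the shapes differ.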

Filter the Checkpoint for Matching Shapes

Sometimes strict=False is not enough, especially when parameter names exist in both models but tensor shapes no longer match. In that case, filter the checkpoint manually.

python
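current = new_model.state_dict()
pretrained = torch.load("model.pt", map_location="cpu")

# Keep only entries whose name and shape both match the new model.
compatible = {
    key: value
    for key, value in pretrained.items()
    if key in current and value.shape == current[key].shape
}

# Overlay the compatible weights on the new model's own values, then load the
# complete dictionary so the default strict check still passes.
current.update(compatible)
new_model.load_state_dict(current)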

This pattern is very common when the feature extractor is unchanged but the output layer has a different number of classes.

It also gives you a convenient place to log which parameters were reused. That is helpful during debugging, because partial loading problems are often discovered only after training starts behaving differently than expected.
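
One way to get that log is to wrap the filter in a small helper that records what it copied and what it left alone. The function name `load_compatible` and the toy models below are hypothetical, chosen only to keep the sketch self-contained.

```python
import torch
import torch.nn as nn


def load_compatible(model, pretrained):
    """Copy tensors whose name and shape both match; return (loaded, skipped) keys."""
    current = model.state_dict()
    loaded, skipped = [], []
    for key, value in pretrained.items():
        if key in current and value.shape == current[key].shape:
            current[key] = value
            loaded.append(key)
        else:
            skipped.append(key)
    model.load_state_dict(current)
    return loaded, skipped


old = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
new = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))  # new 5-class head

loaded, skipped = load_compatible(new, old.state_dict())
print("loaded:", loaded)
print("skipped:", skipped)
print("first layer reused:", torch.equal(new[0].weight, old[0].weight))
```

Logging both lists at load time makes it obvious when a layer you expected to reuse was silently skipped because of a renamed key or a changed shape.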

Example: Replace the Classifier Head

Here is a small transfer-learning style example with torchvision:

python
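import torch
import torch.nn as nn
from torchvision import models

# Stand-in for a real pretrained checkpoint: a 1000-class ResNet-18.
old_model = models.resnet18(weights=None)
torch.save(old_model.state_dict(), "resnet18_old.pt")

# New model with a 5-class head.
new_model = models.resnet18(weights=None)
new_model.fc = nn.Linear(new_model.fc.in_features, 5)

checkpoint = torch.load("resnet18_old.pt", map_location="cpu")
# fc.weight and fc.bias still have the old 1000-class shape. strict=False would
# raise a size-mismatch error for them, so drop them before loading.
checkpoint = {k: v for k, v in checkpoint.items() if not k.startswith("fc.")}
missing, unexpected = new_model.load_state_dict(checkpoint, strict=False)

print("missing:", missing)        # ['fc.weight', 'fc.bias']
print("unexpected:", unexpected)  # []

In this case, the backbone layers load from the checkpoint, while the new fc layer is intentionally left out: it keeps its fresh initialization and is reported under missing.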

Freeze What You Reused

After loading pretrained layers, you may want to freeze them before fine-tuning the rest of the model.

python
# Freeze every parameter first...
for param in new_model.parameters():
    param.requires_grad = False

# ...then re-enable gradients for the new head only.
for param in new_model.fc.parameters():
    param.requires_grad = True

This is a standard transfer-learning workflow: reuse the pretrained backbone, retrain only the new head, then optionally unfreeze more layers later.

If you later decide to unfreeze additional layers, do it intentionally and usually with a smaller learning rate. That helps preserve the pretrained features instead of immediately overwriting them.
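
A common way to apply a smaller learning rate to the unfrozen layers is to use optimizer parameter groups. The following is a minimal sketch on a toy model; the `backbone` and `head` names and the two learning-rate values are illustrative assumptions, not recommendations.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
backbone, head = model[0], model[2]

# Stage 1: the pretrained backbone is frozen and only the head trains.
for param in backbone.parameters():
    param.requires_grad = False

# Stage 2: unfreeze the backbone on purpose.
for param in backbone.parameters():
    param.requires_grad = True

# Each parameter group carries its own learning rate: small for the reused
# layers, larger for the freshly initialized head.
optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 1e-4},
        {"params": head.parameters(), "lr": 1e-2},
    ],
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(len(group["params"]), "tensors at lr", group["lr"])
```

Building the optimizer after changing requires_grad keeps the parameter groups in step with the layers you actually intend to train.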

Common Pitfalls

  • Saving entire model objects instead of state_dict files, which makes partial loading harder.
  • Assuming strict=False will fix shape mismatches automatically.
  • Forgetting to inspect the missing and unexpected keys after partial loading.
  • Loading a checkpoint whose parameter names changed because of wrappers such as DataParallel.
  • Reusing pretrained layers but forgetting to set requires_grad according to your fine-tuning plan.
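
For the DataParallel pitfall above, the usual symptom is that every checkpoint key starts with "module.". The sketch below simulates such a checkpoint instead of wrapping a real model, then strips the prefix before loading; the variable names are illustrative.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Simulated checkpoint from a DataParallel-wrapped model: every key gains "module.".
wrapped_checkpoint = {"module." + k: v for k, v in model.state_dict().items()}

prefix = "module."
cleaned = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in wrapped_checkpoint.items()
}

fresh = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
# The default strict=True passes only if every key matches after renaming.
result = fresh.load_state_dict(cleaned)
print(list(cleaned))
```

Without the renaming step, strict=False would load nothing and report every key as both missing and unexpected, which is easy to overlook if the returned lists are not inspected.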

Summary

  • Use state_dict files for flexible partial loading in PyTorch.
  • Try load_state_dict(..., strict=False) when some layers differ by design.
  • Filter checkpoint keys manually when names or tensor shapes do not line up cleanly.
  • Replacing the classifier head is a common partial-loading use case.
  • Always inspect loaded keys and decide explicitly which layers to freeze or train.
