How can I load a partial pretrained PyTorch model?
Introduction
Loading part of a pretrained PyTorch model is a common transfer-learning task when your new architecture overlaps with an older checkpoint but is not identical. The safest approach is to work with state_dict objects, copy only matching weights, and be explicit about which layers should be skipped.
Start With state_dict
In PyTorch, partial loading is easiest when you save and load state_dict dictionaries rather than entire model objects. A state_dict is an ordered dictionary that maps each parameter and persistent buffer name (for example, fc.weight) to its tensor.
That format is portable and much easier to adapt when only some layers match.
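A minimal sketch of saving and reloading a state_dict. The SmallNet model and the file name checkpoint.pth are hypothetical and chosen only for illustration:

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = SmallNet()

# Save only the name -> tensor dictionary, not the pickled module object.
torch.save(model.state_dict(), "checkpoint.pth")

# Reload it on CPU and inspect names and shapes before deciding what to reuse.
state_dict = torch.load("checkpoint.pth", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```

Because the keys are plain strings such as features.0.weight and fc.bias, it is easy to see which entries line up with a new architecture.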
Use strict=False When Some Layers Differ
If the new model still contains many of the old parameter names, you can often load the checkpoint directly with strict=False.
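strict=False tells load_state_dict not to raise an error when some keys exist only in the model or only in the checkpoint. It loads every matching key and returns the rest as missing_keys and unexpected_keys. That is useful when you replaced the classifier head, added layers, or removed layers.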
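A minimal sketch, assuming two hypothetical models that share a backbone layer but name their output layers differently:

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical model the checkpoint came from (forward omitted)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 64)
        self.head = nn.Linear(64, 10)


class NewNet(nn.Module):
    """Same backbone name; the old "head" was replaced by a new "classifier"."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 64)
        self.classifier = nn.Linear(64, 5)


old_model = OldNet()
new_model = NewNet()

# strict=False loads matching keys and reports the rest instead of raising.
result = new_model.load_state_dict(old_model.state_dict(), strict=False)
print("missing:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, not in the model
```

Here missing_keys should list the new classifier parameters and unexpected_keys the old head parameters, while backbone.weight and backbone.bias are copied over.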
Filter the Checkpoint for Matching Shapes
strict=False is not enough when a parameter name exists in both models but its tensor shape changed: load_state_dict still raises a size-mismatch error for that key. In that case, filter the checkpoint manually.
This pattern is very common when the feature extractor is unchanged but the output layer has a different number of classes.
It also gives you a convenient place to log which parameters were reused. That is helpful during debugging, because partial loading problems are often discovered only after training starts behaving differently than expected.
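A minimal sketch of name-and-shape filtering with logging. The make_model helper is hypothetical; it builds the same feature layer with a configurable number of output classes:

```python
import torch
import torch.nn as nn


def make_model(num_classes: int) -> nn.Sequential:
    """Hypothetical model: shared feature layer ("0.*") plus classifier ("2.*")."""
    return nn.Sequential(
        nn.Linear(32, 64),
        nn.ReLU(),
        nn.Linear(64, num_classes),
    )


pretrained = make_model(num_classes=1000)
checkpoint = pretrained.state_dict()

model = make_model(num_classes=10)
model_state = model.state_dict()

# Keep only entries whose name exists in the new model AND whose shape matches.
filtered = {
    k: v
    for k, v in checkpoint.items()
    if k in model_state and v.shape == model_state[k].shape
}
skipped = sorted(set(checkpoint) - set(filtered))

# Log what was reused and what was skipped.
print("reused:", sorted(filtered))
print("skipped:", skipped)

# Merge into the model's own state_dict so the final load can stay strict.
model_state.update(filtered)
model.load_state_dict(model_state)
```

Merging into the model's own state_dict means every key is present, so the load can run with the default strict=True and any unexpected naming problem still surfaces as an error.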
Example: Replace the Classifier Head
A typical torchvision case is loading a pretrained ResNet checkpoint into a model whose final fc layer has a different number of outputs. The feature layers load, while the final fc layer is intentionally different and keeps its fresh initialization.
Freeze What You Reused
After loading pretrained layers, you may want to freeze them before fine-tuning the rest of the model.
This is a standard transfer-learning workflow: reuse the pretrained backbone, retrain only the new head, then optionally unfreeze more layers later.
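A minimal sketch of freezing a reused backbone, assuming a hypothetical TransferNet with backbone and head submodules:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransferNet(nn.Module):
    """Hypothetical model: reused backbone plus a freshly initialized head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.head = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


model = TransferNet()

# Freeze the pretrained backbone so no gradients are computed for it.
for param in model.backbone.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that are still trainable (the head).
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-2, momentum=0.9)

# One training step on random data: only the head receives gradients.
x, y = torch.randn(4, 32), torch.randint(0, 10, (4,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
```

After backward(), the frozen backbone parameters still have grad set to None, which is a quick way to confirm the freeze took effect.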
If you later decide to unfreeze additional layers, do it intentionally and usually with a smaller learning rate. That helps preserve the pretrained features instead of immediately overwriting them.
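One way to give unfrozen layers a smaller learning rate is optimizer parameter groups. A minimal sketch, assuming the same hypothetical TransferNet layout; the learning-rate values are illustrative, not recommendations:

```python
import torch
import torch.nn as nn


class TransferNet(nn.Module):
    """Hypothetical model: pretrained backbone plus a new head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.head = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


model = TransferNet()

# Stage 2: unfreeze the backbone that was frozen during head-only training.
for param in model.backbone.parameters():
    param.requires_grad = True

# Small learning rate for pretrained layers, larger default for the new head.
optimizer = torch.optim.SGD(
    [
        {"params": model.backbone.parameters(), "lr": 1e-4},
        {"params": model.head.parameters()},  # falls back to lr=1e-2 below
    ],
    lr=1e-2,
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(group["lr"], len(group["params"]))
```

Groups without their own "lr" entry inherit the optimizer-level default, so only the pretrained layers need an explicit override.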
Common Pitfalls
- Saving entire model objects instead of state_dict files, which makes partial loading harder.
- Assuming strict=False will fix shape mismatches automatically.
- Forgetting to inspect the missing and unexpected keys after partial loading.
- Loading a checkpoint whose parameter names changed because of wrappers such as DataParallel.
- Reusing pretrained layers but forgetting to set requires_grad according to your fine-tuning plan.
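Wrappers such as DataParallel store the real model under a submodule named module, so a checkpoint saved from the wrapper has keys like module.0.weight. A minimal sketch that simulates such a checkpoint and strips the prefix; strip_prefix is a hypothetical helper, not a PyTorch API:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4))

# Simulated checkpoint saved from a DataParallel-wrapped model:
# every key gains a "module." prefix.
wrapped_state = {"module." + k: v for k, v in model.state_dict().items()}
print(list(wrapped_state))  # ['module.0.weight', 'module.0.bias']


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop a leading prefix from keys that have it."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# With the prefix removed, the keys match and a strict load succeeds.
result = model.load_state_dict(strip_prefix(wrapped_state))
```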
Summary
- Use state_dict files for flexible partial loading in PyTorch.
- Try load_state_dict(..., strict=False) when some layers differ by design.
- Filter checkpoint keys manually when names or tensor shapes do not line up cleanly.
- Replacing the classifier head is a common partial-loading use case.
- Always inspect loaded keys and decide explicitly which layers to freeze or train.

