Restore a Subset of Variables in TensorFlow

Introduction

Restoring only part of a TensorFlow checkpoint is a common need in transfer learning, fine-tuning, and model surgery. You might want to load the feature extractor from an old checkpoint while leaving a new classification head randomly initialized.

In modern TensorFlow, the cleanest approach is object-based checkpointing. Whenever possible, instead of restoring from lists of raw variable names, you restore the specific objects you want and intentionally leave the rest unmatched.

Object-Based Checkpoint Restore in TensorFlow 2

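The main tools are tf.train.Checkpoint and Keras models or layers. TensorFlow matches saved values to variables by their path through the object graph, that is, the chain of attribute or keyword names leading from the root checkpoint object to each variable. If you build a checkpoint object that reaches only the parts you want, along the same paths that were used when saving, TensorFlow restores only those tracked variables.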

Example:

```python
import tensorflow as tf


class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Feature extractor we want to reuse.
        self.backbone = tf.keras.Sequential([
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(8, activation="relu"),
        ])
        # Task-specific classifier that should stay freshly initialized.
        self.head = tf.keras.layers.Dense(3)

    def call(self, x):
        x = self.backbone(x)
        return self.head(x)


model = Model()
_ = model(tf.zeros((1, 4)))  # Build the model so its variables exist.

# Assumes /tmp/my_checkpoint stores the backbone under the root key "backbone".
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore("/tmp/my_checkpoint")
status.expect_partial()  # Skipping the head is intentional.
```

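In this example, only the backbone is restored, assuming /tmp/my_checkpoint was written with the backbone stored under the root key backbone (for example with tf.train.Checkpoint(backbone=...).write(...)). The head remains newly initialized because it is not reachable from the checkpoint object. If the checkpoint was instead saved from the whole model, the backbone sits one level deeper and the restoring object has to reproduce that nesting.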

Why expect_partial() Matters

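When you intentionally restore only part of a checkpoint, some saved values are never matched, and TensorFlow logs warnings about those unresolved objects (typically when the load status is garbage-collected). That is expected. Calling status.expect_partial() silences those warnings and makes your intent explicit.

Use it only when partial restoration is deliberate, not to silence a mismatch you do not understand. If you expect everything to match, check the status instead: status.assert_consumed() fails if any checkpoint value or any of your tracked variables went unmatched, and status.assert_existing_objects_matched() fails if a variable you have already created was not found in the checkpoint.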
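TensorFlow offers two stronger checks, assert_consumed() and assert_existing_objects_matched(), and the difference between them is easiest to see on a tiny example. The sketch below is an illustration under assumptions: it uses a plain tf.Module (the class TwoPart and its attribute names are hypothetical) instead of the Keras model, and it pairs write() with read(), the lower-level counterparts of save() and restore().

```python
import os
import tempfile

import tensorflow as tf


class TwoPart(tf.Module):
    """Hypothetical module: tf.Module tracks each variable under its attribute name."""

    def __init__(self, head_size):
        super().__init__()
        self.backbone = tf.Variable(tf.ones([2]))
        self.head = tf.Variable(tf.ones([head_size]))


# Save a checkpoint that contains both "net/backbone" and "net/head".
prefix = os.path.join(tempfile.mkdtemp(), "two_part")
path = tf.train.Checkpoint(net=TwoPart(head_size=3)).write(prefix)

# Restore only the backbone; the placeholder Checkpoint mirrors the "net" node.
backbone = tf.Variable(tf.zeros([2]))
status = tf.train.Checkpoint(net=tf.train.Checkpoint(backbone=backbone)).read(path)

# Passes: every variable we created was matched by a checkpoint value.
status.assert_existing_objects_matched()

# Fails: the checkpoint still holds "net/head", which nothing consumed.
try:
    status.assert_consumed()
    consumed_everything = True
except AssertionError:
    consumed_everything = False

# Skipping the head is deliberate, so silence the unmatched-value warnings.
status.expect_partial()
```

In practice, assert_existing_objects_matched() is the natural check for deliberate partial restores, while assert_consumed() suits the case where the checkpoint and program are supposed to line up exactly.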

A Transfer-Learning Example

A typical use case is reusing a pretrained encoder while replacing the classifier layer:

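```python
# Save a checkpoint of the whole original model under the root key "model".
old_model = Model()
_ = old_model(tf.zeros((1, 4)))

full_ckpt = tf.train.Checkpoint(model=old_model)
full_ckpt.write("/tmp/full_model_ckpt")

# New model for a 5-class task: swap in a new head, then build it.
new_model = Model()
new_model.head = tf.keras.layers.Dense(5)
_ = new_model(tf.zeros((1, 4)))

# Mirror the saved path model -> backbone with a placeholder Checkpoint,
# so only the backbone weights are matched and loaded.
partial_ckpt = tf.train.Checkpoint(
    model=tf.train.Checkpoint(backbone=new_model.backbone)
)
partial_ckpt.restore("/tmp/full_model_ckpt").expect_partial()
```

Here partial_ckpt wraps the backbone in a placeholder tf.train.Checkpoint so that it reproduces the model/backbone path of the saved checkpoint. TensorFlow follows that path and loads the saved weights into new_model.backbone. The new five-class head is not reachable from partial_ckpt, so it keeps its own initialization and can be trained for the new task. A flat tf.train.Checkpoint(backbone=new_model.backbone) would not work here: the saved checkpoint has no root-level backbone entry, so nothing would be restored, and expect_partial() would hide that.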
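Because matching follows the path recorded at save time, a checkpoint saved as tf.train.Checkpoint(model=...) keeps the backbone under model/backbone, and the restoring object has to reproduce that nesting. The sketch below checks this with a small hypothetical tf.Module standing in for the Keras model (TinyModel and its fill values are illustrative), comparing a flat checkpoint object with a mirrored one.

```python
import os
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the Keras model: a backbone and a head variable."""

    def __init__(self, fill, head_size):
        super().__init__()
        self.backbone = tf.Variable(tf.fill([4], fill))
        self.head = tf.Variable(tf.fill([head_size], fill))


prefix = os.path.join(tempfile.mkdtemp(), "full_model_ckpt")
path = tf.train.Checkpoint(model=TinyModel(1.0, head_size=3)).write(prefix)

# Flat layout: the checkpoint has no root-level "backbone", so nothing is restored.
flat = TinyModel(0.0, head_size=5)
tf.train.Checkpoint(backbone=flat.backbone).read(path).expect_partial()

# Mirrored layout: "model" -> "backbone" matches the saved path.
mirrored = TinyModel(0.0, head_size=5)
tf.train.Checkpoint(
    model=tf.train.Checkpoint(backbone=mirrored.backbone)
).read(path).expect_partial()

print(flat.backbone.numpy())      # [0. 0. 0. 0.]
print(mirrored.backbone.numpy())  # [1. 1. 1. 1.]
print(mirrored.head.numpy())      # [0. 0. 0. 0. 0.]
```

Note that the flat version fails silently: expect_partial() suppresses the warnings that would otherwise point at the missed match.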

Historical TensorFlow 1 Pattern

Older TensorFlow 1 code often used tf.train.Saver(var_list=...) (available in TensorFlow 2 as tf.compat.v1.train.Saver, and only in graph mode) to restore a named subset of variables. You will still meet it in legacy code, but it is not the preferred style for new TensorFlow 2 projects.
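For reference when reading such code, here is a minimal TF1-style sketch run through tf.compat.v1 in graph mode. The variable names backbone_w and head_w and the temporary checkpoint location are illustrative assumptions. The Saver built with var_list restores only the variables it is given, matched by name; everything else keeps whatever the initializer produced.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
prefix = os.path.join(tempfile.mkdtemp(), "legacy.ckpt")

# Old graph: save both variables, keyed by their names.
with tf.Graph().as_default():
    tf1.Variable([1.0, 2.0], name="backbone_w")
    tf1.Variable([3.0, 4.0, 5.0], name="head_w")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, prefix)

# New graph: the head changed shape, so restore only the backbone by name.
with tf.Graph().as_default():
    backbone_w = tf1.Variable([0.0, 0.0], name="backbone_w")
    head_w = tf1.Variable(tf.zeros([5]), name="head_w")
    saver = tf1.train.Saver(var_list=[backbone_w])  # only these are restored
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())  # initialize everything first
        saver.restore(sess, prefix)  # then overwrite the listed variables
        restored_backbone, fresh_head = sess.run([backbone_w, head_w])
```

The matching here depends entirely on the string "backbone_w" being identical in both graphs, which is the brittleness the object-based approach avoids.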

The TensorFlow 2 object-based approach is usually safer because it matches tracked objects instead of relying entirely on brittle variable-name conventions.

When Partial Restore Fails

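Partial restoration still depends on compatibility: the restored variables must match the checkpoint in both shape and meaning. If you change a layer width from 16 to 32, the saved values no longer fit the resized variables and the restore fails with a shape-mismatch error. Matching shapes alone are not enough either; a variable with the right shape but a different role would load silently with meaningless values.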
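A quick way to see the shape failure is to save a 16-element variable and restore it into a 32-element one. This sketch uses a bare tf.Variable rather than a Dense layer to stay short; the exact exception class is an assumption on my part, so it catches both ValueError and tf.errors.InvalidArgumentError.

```python
import os
import tempfile

import tensorflow as tf

prefix = os.path.join(tempfile.mkdtemp(), "width16")
path = tf.train.Checkpoint(kernel=tf.Variable(tf.ones([16]))).write(prefix)

wider = tf.Variable(tf.zeros([32]))  # the "layer width" changed from 16 to 32
try:
    # In eager mode the value is assigned immediately, so the mismatch surfaces here.
    tf.train.Checkpoint(kernel=wider).read(path)
    restore_failed = False
except (ValueError, tf.errors.InvalidArgumentError) as err:
    restore_failed = True
    print("restore failed:", type(err).__name__)
```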

That means partial restore is best for:

  • reusing unchanged submodules
  • loading a pretrained backbone into a new task-specific model
  • resuming only some tracked components such as the model but not the optimizer

It is not a magic tool for incompatible architectures.

Common Pitfalls

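A common mistake is restoring before the model has been built. Keras layers usually create their variables on the first call, so an unbuilt model gives the status checks little to compare against. tf.train.Checkpoint can defer restoring values until the matching variables are created, but calling the model once before restoring makes mismatches surface right away.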
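Object-based checkpoints soften this problem: if a matching variable does not exist yet, tf.train.Checkpoint queues the saved value and assigns it when the variable is created. The sketch below shows that with a hypothetical LazyDense module that, like Keras layers, creates its kernel on the first call. Until that call, the pending value is invisible to status checks, which is why building first remains the safer habit.

```python
import os
import tempfile

import tensorflow as tf


class LazyDense(tf.Module):
    """Hypothetical layer that creates its kernel on first call, like Keras layers."""

    def __init__(self, units, fill=0.0):
        super().__init__()
        self.units = units
        self.fill = fill
        self.kernel = None  # not created until the input size is known

    def __call__(self, x):
        if self.kernel is None:
            self.kernel = tf.Variable(tf.fill([x.shape[-1], self.units], self.fill))
        return tf.matmul(x, self.kernel)


# Save a built layer whose kernel is all ones.
source = LazyDense(units=2, fill=1.0)
source(tf.zeros([1, 3]))
prefix = os.path.join(tempfile.mkdtemp(), "lazy")
path = tf.train.Checkpoint(layer=source).write(prefix)

# Restore into an unbuilt layer: the kernel value is queued, not yet assigned.
target = LazyDense(units=2)
status = tf.train.Checkpoint(layer=target).read(path)

# Creating the variable on the first call triggers the deferred restore,
# so the kernel holds the checkpoint's ones instead of its zero initializer.
target(tf.zeros([1, 3]))
print(target.kernel.numpy())
```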

Another mistake is using expect_partial() to hide an accidental mismatch instead of an intentional one. If you are not deliberately skipping variables, investigate the mismatch.
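One way to investigate is to list what the checkpoint actually stores. tf.train.list_variables returns (key, shape) pairs, and for object-based checkpoints each key spells out the object path, which makes it easy to compare against the structure you are restoring into. The small Net module below is a hypothetical stand-in for a real model.

```python
import os
import tempfile

import tensorflow as tf


class Net(tf.Module):
    """Hypothetical model with a backbone and a head variable."""

    def __init__(self):
        super().__init__()
        self.backbone = tf.Variable(tf.ones([4]))
        self.head = tf.Variable(tf.ones([3]))


prefix = os.path.join(tempfile.mkdtemp(), "inspect")
path = tf.train.Checkpoint(model=Net()).write(prefix)

# Keys look like "model/backbone/.ATTRIBUTES/VARIABLE_VALUE"; the listing also
# includes an internal entry that stores the object graph itself.
for name, shape in tf.train.list_variables(path):
    print(name, shape)
```

If the keys start with model/ but your restoring object puts backbone at the root, that mismatch explains why nothing was loaded.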

A third issue is restoring by variable names in modern TensorFlow code when object-based checkpointing would be simpler and more robust.

Finally, remember that optimizer state is tracked separately. Restoring model weights alone does not automatically restore optimizer slots unless you include the optimizer in the checkpoint object.

Summary

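  • In TensorFlow 2, restore a subset of variables by building a checkpoint object that reaches only the parts you want, along the same object paths used when the checkpoint was saved.
  • Build the model (call it once) before restoring so the variables exist and mismatches surface early.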
  • Use expect_partial() when partial restoration is intentional.
  • Partial restore is ideal for transfer learning and reused submodules.
  • It only works when the restored variables are still shape-compatible with the checkpoint.
