Display a TensorFlow Model Summary as in Keras

Introduction

In TensorFlow 2.x, model.summary() works directly on Keras models: Sequential and Functional models as soon as their input shape is known, and subclassed models once they have been built. TensorFlow 1.x graphs have no equivalent method, so there you iterate over tf.trainable_variables() and print each variable's shape and parameter count yourself. For a visual diagram rather than a text table, use tf.keras.utils.plot_model().

Keras model.summary() (Standard Approach)

python
import tensorflow as tf

# Sequential model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.summary()

Output:

 
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 128)               100480
 dropout (Dropout)           (None, 128)               0
 dense_1 (Dense)             (None, 64)                8256
 dense_2 (Dense)             (None, 10)                650
=================================================================
Total params: 109386 (427.29 KB)
Trainable params: 109386 (427.29 KB)
Non-trainable params: 0 (0.00 Byte)
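
The Param # column can be checked by hand. The sketch below assumes the usual Dense parameter count (inputs × units weights plus one bias per unit) and 4 bytes per float32 parameter. The helper name dense_params is illustrative and not part of Keras.

```python
def dense_params(n_inputs: int, units: int) -> int:
    """Weights (n_inputs * units) plus one bias per unit."""
    return n_inputs * units + units

# Layer widths of the Sequential model above: 784 -> 128 -> 64 -> 10
layer_sizes = [784, 128, 64, 10]
counts = [dense_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)
size_kb = total * 4 / 1024  # float32 parameters take 4 bytes each

print(counts)               # [100480, 8256, 650]
print(total)                # 109386
print(f"{size_kb:.2f} KB")  # 427.29 KB
```

Dropout holds no weights, which is why its row shows 0 parameters.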

Functional API Summary

python
import tensorflow as tf

inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='my_cnn')
model.summary()

# Add a "Trainable" column and expand any nested sub-models
# (show_trainable requires a recent TF 2.x release)
model.summary(show_trainable=True, expand_nested=True)
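
The shapes and parameter counts that summary() reports for this CNN can be estimated in advance. This is a rough sketch that assumes the Keras defaults used above: 'valid' padding and stride 1 for Conv2D, and a 2×2 pool with stride 2 for MaxPooling2D. The helper names are illustrative, not Keras functions.

```python
def conv2d_params(kernel: int, in_channels: int, filters: int) -> int:
    """kernel x kernel weights per (input channel, filter) pair, plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(n_inputs: int, units: int) -> int:
    return n_inputs * units + units

def valid_conv_size(size: int, kernel: int) -> int:
    """Spatial size after a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1

# Spatial size through the network: Conv2D(3x3) -> MaxPooling2D(2x2) -> Conv2D(3x3)
after_conv1 = valid_conv_size(224, 3)          # 222
after_pool = after_conv1 // 2                  # 111
after_conv2 = valid_conv_size(after_pool, 3)   # 109

params = {
    "conv2d": conv2d_params(3, 3, 32),     # 896
    "conv2d_1": conv2d_params(3, 32, 64),  # 18496
    "dense": dense_params(64, 128),        # 8320 (GlobalAveragePooling2D leaves 64 features)
    "dense_1": dense_params(128, 10),      # 1290
}
total = sum(params.values())
print(after_conv1, after_pool, after_conv2)  # 222 111 109
print(total)                                 # 29002
```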

Subclassed Model Summary

A subclassed model must be built before summary() works, either by calling it once on sample input or by calling build():

python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(64, activation='relu')
        self.output_layer = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.output_layer(x)

model = MyModel()

# This fails before the model is built:
# model.summary()  # ValueError: This model has not yet been built

# Option 1: Call with sample input
model(tf.zeros((1, 784)))
model.summary()

# Option 2: Explicitly build
model.build(input_shape=(None, 784))
model.summary()
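
In some tf.keras versions the summary of a subclassed model lists output shapes as "multiple" or "?" rather than concrete values. One common workaround, sketched here under the assumption that call() uses only Keras layers, is to trace call() on a symbolic tf.keras.Input and wrap the result in a Functional model. The wrapper name is illustrative.

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(64, activation='relu')
        self.output_layer = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.output_layer(x)

model = MyModel()

# Trace call() on a symbolic input so every layer gets a static output shape
inputs = tf.keras.Input(shape=(784,))
wrapper = tf.keras.Model(inputs=inputs, outputs=model.call(inputs), name='my_model_functional')
wrapper.summary()
```

The wrapper shares its layers, and therefore its weights, with the original model, so it is only a view for inspection and not a copy.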

Visual Model Plot

python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Save model architecture as an image
tf.keras.utils.plot_model(
    model,
    to_file='model_architecture.png',
    show_shapes=True,          # Show input/output shapes
    show_layer_names=True,     # Show layer names
    show_layer_activations=True,  # Show activation functions
    show_dtype=False,
    rankdir='TB',              # Top to bottom (or 'LR' for left to right)
    dpi=150,
)

plot_model() requires both pydot and graphviz:

bash
pip install pydot graphviz
# Also install system graphviz:
# Ubuntu: sudo apt install graphviz
# macOS: brew install graphviz
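
Before calling plot_model(), it can help to check that both pieces are present. The sketch below uses only the standard library: it looks for an importable pydot module and for the Graphviz dot executable on PATH. The function name plot_model_dependencies is hypothetical, and the check is a convenience, not the test Keras itself performs.

```python
import importlib.util
import shutil

def plot_model_dependencies() -> dict:
    """Report whether the pydot package and the Graphviz 'dot' binary can be found."""
    return {
        "pydot": importlib.util.find_spec("pydot") is not None,
        "dot": shutil.which("dot") is not None,
    }

status = plot_model_dependencies()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'missing'}")
```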

Custom Summary Function

python
import math

import tensorflow as tf

def detailed_summary(model):
    """Print a per-layer table with output shapes and parameter counts."""
    print(f"Model: {model.name}")
    print(f"{'Layer':<30} {'Type':<25} {'Output Shape':<20} {'Params':<10} {'Trainable'}")
    print("=" * 110)

    total_params = 0
    trainable_params = 0

    for layer in model.layers:
        # output_shape is not available on every layer or Keras version, so fall back to 'N/A'
        output_shape = str(layer.output_shape) if hasattr(layer, 'output_shape') else 'N/A'
        params = layer.count_params()
        # Count trainable weights only: layers such as BatchNormalization
        # also hold non-trainable weights, which count_params() includes
        layer_trainable = sum(math.prod(w.shape) for w in layer.trainable_weights)
        trainable = "Yes" if layer_trainable > 0 else "No"
        total_params += params
        trainable_params += layer_trainable

        print(f"{layer.name:<30} {layer.__class__.__name__:<25} {output_shape:<20} {params:<10} {trainable}")

    print("=" * 110)
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,}")
    print(f"Non-trainable params: {total_params - trainable_params:,}")

# Small example model so the function can be run as-is
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(100,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10),
])
detailed_summary(model)

Printing Summary to a String or File

python
import datetime
import io

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, input_shape=(100,)),
    tf.keras.layers.Dense(10),
])

# Capture summary as a string.
# **kwargs is accepted defensively: some Keras releases pass extra
# keyword arguments (such as line_break) to print_fn.
string_buffer = io.StringIO()
model.summary(print_fn=lambda x, **kwargs: string_buffer.write(x + '\n'))
summary_string = string_buffer.getvalue()
print(summary_string)

# Write summary to a file
with open('model_summary.txt', 'w') as f:
    model.summary(print_fn=lambda x, **kwargs: f.write(x + '\n'))

# Log summary to TensorBoard
log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(log_dir)
with writer.as_default():
    tf.summary.text("model_summary", summary_string, step=0)
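
TensorBoard's Text dashboard renders its input as Markdown, so the summary table can lose its column alignment there. One possible fix, sketched below, is to indent every line by four spaces so that Markdown treats the summary as preformatted text. The helper name as_markdown_code_block is hypothetical, and the sample string stands in for a real summary.

```python
def as_markdown_code_block(text: str) -> str:
    """Indent each line by four spaces so Markdown renders it as preformatted text."""
    return "\n".join("    " + line for line in text.splitlines())

# Stand-in for the captured summary string
example = 'Model: "sequential"\n dense (Dense)   (None, 64)   6464'
formatted = as_markdown_code_block(example)
print(formatted)
```

In the logging code above, this would mean passing as_markdown_code_block(summary_string) to tf.summary.text() instead of the raw string.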

TensorFlow 1.x: Manual Variable Summary

python
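# TF1 has no model.summary(), so iterate over the variables manually
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Build the graph first; this single example variable stands in for a real model
tf.get_variable("dense/kernel", shape=(784, 128))

total_params = 0
for var in tf.trainable_variables():
    shape = var.get_shape()
    # num_elements() returns a plain int; multiplying the dimensions directly
    # yields Dimension objects in TF1 mode, which break the ':,' format below
    params = shape.num_elements()
    total_params += params
    print(f"{var.name:<50} {str(shape):<20} {params}")

print(f"\nTotal trainable parameters: {total_params:,}")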

Common Pitfalls

  • Calling summary() on an unbuilt subclassed model: Subclassed tf.keras.Model does not know its input shape until it processes data. Call model.build(input_shape=...) or model(sample_input) before calling summary(), otherwise it raises ValueError.
  • Expecting summary() to show nested model details: By default, summary() shows nested models (like a pretrained backbone) as a single line. Use model.summary(expand_nested=True) to recursively expand nested models and show all internal layers.
  • Missing graphviz for plot_model: tf.keras.utils.plot_model() requires both the Python pydot package and the system graphviz binaries. Installing only the Python packages without the system binaries typically produces an error or message telling you to install pydot and graphviz, and no image is written.
  • Miscounting parameters for shared layers: A layer used multiple times in a Functional model shares one set of weights and appears only once in model.layers, and summary() reports the correct total. Hand-written counters can still double-count, for example when they add up every call site or sum over nested models that reuse a layer. Use model.count_params() for the accurate total.
  • Summary showing "multiple" for output shapes: This typically appears for layers inside subclassed models, or for a layer called more than once with different shapes, where Keras cannot report a single static output shape. This is expected, and expand_nested=True does not change it. Building the same layers with the Functional API on a tf.keras.Input gives concrete shapes.
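
The shared-layer case can be checked with a small example. This is a minimal sketch, with layer sizes and names chosen only for illustration: one Dense layer is applied to two inputs, and model.count_params() reports its weights once.

```python
import tensorflow as tf

# One Dense layer applied to two different inputs: a single set of weights
shared = tf.keras.layers.Dense(8, name="shared_dense")
input_a = tf.keras.Input(shape=(4,))
input_b = tf.keras.Input(shape=(4,))
merged = tf.keras.layers.Concatenate()([shared(input_a), shared(input_b)])
model = tf.keras.Model(inputs=[input_a, input_b], outputs=merged)

per_weight_set = 4 * 8 + 8         # 40 parameters in the shared layer
per_call_site = 2 * per_weight_set # 80: the wrong answer if each use is counted

print(model.count_params())  # 40
print([layer.name for layer in model.layers].count("shared_dense"))  # 1
```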

Summary

  • Use model.summary() on any Keras model (Sequential, Functional, or subclassed) to display layer names, output shapes, and parameter counts
  • For subclassed models, call model.build(input_shape=...) or pass sample data before calling summary()
  • Use tf.keras.utils.plot_model() to generate a visual diagram of the model architecture
  • Capture summary output as a string with print_fn=lambda x: buffer.write(x + '\n')
  • Use expand_nested=True to show internal layers of nested models
  • For TF1 graphs, manually iterate over tf.trainable_variables() to display variable shapes and parameter counts
