How can I make a trainable parameter in Keras?
Creating Trainable Parameters in Keras
Trainable parameters are the core of neural network learning. They are the weights and biases that the optimizer updates during training, using gradients computed by backpropagation, to minimize the loss function. Keras provides built-in layers like Dense and Conv2D that manage their own trainable parameters, but sometimes you need to define custom trainable variables, for example when building custom layers, implementing attention mechanisms, or adding learnable scaling factors.
This article covers three practical approaches to creating trainable parameters in Keras with TensorFlow 2.x.
Approach 1: Custom Layer with self.add_weight
The most common and recommended way to create trainable parameters is by subclassing tf.keras.layers.Layer and using the add_weight method. This integrates cleanly with Keras model serialization, summary printing, and the training loop.
Using this layer in a model:
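A minimal sketch, assuming TensorFlow 2.x with tf.keras: a dense-like layer (the class name SimpleDense is hypothetical) creates its kernel and bias with add_weight inside build, and is then used in a Sequential model like any built-in layer. The training data is random and only serves to show that fit() updates the custom weights.

```python
import numpy as np
import tensorflow as tf


class SimpleDense(tf.keras.layers.Layer):
    """A dense-like layer whose kernel and bias are created with add_weight."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # The input dimension is only known here, not in __init__.
        input_dim = int(input_shape[-1])
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_dim, self.units),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias

    def get_config(self):
        # Store constructor arguments so the layer can be re-created when loading.
        config = super().get_config()
        config.update({"units": self.units})
        return config


# The custom layer slots into a model like any built-in layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    SimpleDense(16),
    tf.keras.layers.ReLU(),
    SimpleDense(1),
])
model.compile(optimizer="adam", loss="mse")

# Random toy data, only to exercise the training loop.
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model.fit(x, y, epochs=2, verbose=0)
model.summary()  # the custom kernels and biases appear in the parameter counts
```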
The build method is called automatically the first time the layer processes input. This deferred initialization means you do not need to specify the input dimension at construction time, which makes layers more flexible and reusable.
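To see the deferred initialization directly, here is a small sketch using the built-in Dense layer, which follows the same build pattern: it owns no weights until its first call, and then infers the kernel's input dimension from the data.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
print(layer.built)         # False
print(len(layer.weights))  # 0: no input shape yet, so no kernel or bias

layer(tf.ones((2, 3)))     # the first call triggers build() with the input shape
print(layer.built)                # True
print(tuple(layer.kernel.shape))  # (3, 4): input dimension 3 was inferred
```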
Approach 2: Learnable Scalar or Vector Parameters
Sometimes you need a single learnable scalar or a small vector that acts as a tunable coefficient. This is common in architectures that use learnable temperature scaling, attention score weighting, or feature gating.
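A minimal sketch of this pattern, assuming TensorFlow 2.x: a hypothetical LearnableScale layer multiplies its input by either one learnable scalar (shape ()) or one learnable value per feature, starting from a constant initializer. A temperature or gating coefficient would be built the same way.

```python
import numpy as np
import tensorflow as tf


class LearnableScale(tf.keras.layers.Layer):
    """Multiplies its input by a learnable scalar or a per-feature vector."""

    def __init__(self, per_feature=False, init_value=1.0, **kwargs):
        super().__init__(**kwargs)
        self.per_feature = per_feature
        self.init_value = init_value

    def build(self, input_shape):
        # shape=() creates a single scalar; (features,) creates one value per feature.
        shape = (int(input_shape[-1]),) if self.per_feature else ()
        self.scale = self.add_weight(
            name="scale",
            shape=shape,
            initializer=tf.keras.initializers.Constant(self.init_value),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # Broadcasting applies the scale to every example in the batch.
        return inputs * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"per_feature": self.per_feature, "init_value": self.init_value})
        return config


x = tf.ones((2, 3))
scalar_scale = LearnableScale()                                  # one learnable number
vector_scale = LearnableScale(per_feature=True, init_value=0.5)  # one per feature
y_scalar = scalar_scale(x)  # every entry is 1.0 until training changes the scale
y_vector = vector_scale(x)  # every entry is 0.5 at initialization
```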
This pattern is useful for adding a learnable normalization step or for implementing skip connections with a tunable blending factor:
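For the skip-connection case, a minimal sketch assuming TensorFlow 2.x: a hypothetical LearnableBlend layer holds one unconstrained scalar and passes it through a sigmoid, so the blending factor stays in (0, 1). The sigmoid is a design choice for this sketch, not a Keras requirement; it keeps the mix a convex combination of the two inputs.

```python
import numpy as np
import tensorflow as tf


class LearnableBlend(tf.keras.layers.Layer):
    """Blends a shortcut and a residual branch with a learnable factor in (0, 1)."""

    def build(self, input_shape):
        # Unconstrained scalar; sigmoid(0.0) = 0.5, so training starts from an even mix.
        self.blend_logit = self.add_weight(
            name="blend_logit", shape=(), initializer="zeros", trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        shortcut, residual = inputs
        alpha = tf.sigmoid(self.blend_logit)
        return alpha * residual + (1.0 - alpha) * shortcut


# Skip connection: the output is a learned mix of the input and a Dense branch.
inputs = tf.keras.Input(shape=(16,))
branch = tf.keras.layers.Dense(16, activation="relu")(inputs)
outputs = LearnableBlend()([inputs, branch])
model = tf.keras.Model(inputs, outputs)

# Standalone check: shortcut of zeros, residual of ones.
blend = LearnableBlend()
mixed = blend([tf.zeros((1, 4)), tf.ones((1, 4))])  # 0.5 everywhere at initialization
```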
Approach 3: Using tf.Variable Directly
For quick experiments or when working outside of the Keras layer system, you can create trainable parameters with tf.Variable and include them manually:
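A minimal sketch of this approach, assuming TensorFlow 2.x: two plain tf.Variable objects are fitted to toy data generated from y = 3x + 2 with a custom GradientTape loop. Because no Keras layer tracks these variables, they have to be passed explicitly to tape.gradient and to the optimizer.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Plain TensorFlow variables; no Keras layer or model registers them.
w = tf.Variable(tf.random.normal(()), name="w")
b = tf.Variable(0.0, name="b")
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Noise-free toy data for y = 3x + 2.
x = tf.random.uniform((256, 1), -1.0, 1.0)
y = 3.0 * x + 2.0

for step in range(300):
    with tf.GradientTape() as tape:
        pred = w * x + b
        loss = tf.reduce_mean(tf.square(pred - y))
    # The variables are listed by hand because nothing tracks them for us.
    grads = tape.gradient(loss, [w, b])
    optimizer.apply_gradients(zip(grads, [w, b]))

print(w.numpy(), b.numpy())  # approaches 3.0 and 2.0
```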
While this works, add_weight is preferred because it registers the variable with the layer, so it is listed in trainable_weights and saved with the model automatically, and because it accepts initializer, regularizer, and constraint arguments. The tf.Variable approach is mainly useful when integrating non-Keras code or porting from raw TensorFlow.
Freezing and Unfreezing Parameters
You can control which parameters are trainable at runtime. This is essential for transfer learning, where you freeze pretrained layers and only train the new layers:
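A minimal sketch, assuming TensorFlow 2.x: a small Sequential model named base stands in for a pretrained feature extractor (in practice it might come from tf.keras.applications). Setting base.trainable = False freezes every weight inside it, so only the new head is trained; unfreezing later for fine-tuning is followed by a fresh compile call.

```python
import tensorflow as tf

# Hypothetical stand-in for a pretrained feature extractor.
base = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
], name="base")

base.trainable = False  # freezes every weight inside `base`

inputs = tf.keras.Input(shape=(32,))
features = base(inputs)
outputs = tf.keras.layers.Dense(10, activation="softmax")(features)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

frozen_trainable = len(model.trainable_weights)
print(frozen_trainable)                  # 2: only the new head's kernel and bias
print(len(model.non_trainable_weights))  # 4: the frozen base layers

# Unfreeze for fine-tuning, then recompile so training picks up the change.
base.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy")
unfrozen_trainable = len(model.trainable_weights)
print(unfrozen_trainable)                # 6
```

A low learning rate after unfreezing is a common choice so that fine-tuning does not wipe out the pretrained weights.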
Inspecting Trainable Parameters
You can list all trainable parameters in a model to verify your custom layers are set up correctly:
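A minimal sketch, assuming TensorFlow 2.x: iterate over model.trainable_weights to see each variable's name and shape, and sum the element counts to cross-check the totals reported by model.summary(). Exact variable names differ between Keras versions, so only the shapes and counts are relied on here.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(1),
])

for var in model.trainable_weights:
    # Each entry is a variable with a name and a shape.
    print(var.name, tuple(var.shape))

# Kernels and biases: 8*4 + 4 + 4*1 + 1 = 41 values.
n_trainable = sum(int(np.prod(v.shape)) for v in model.trainable_weights)
print(n_trainable)  # 41
model.summary()     # the "Trainable params" line should report the same total
```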
Common Pitfalls
- Forgetting to call super().build(). End your build method with super().build(input_shape) so the layer is marked as built. Recent Keras versions also set this flag when build runs as part of the layer's first call, but if the flag is never set, a later call can run build again and create a second set of weights.
- Not implementing get_config. Without get_config, your custom layer cannot be reliably saved with model.save and restored with tf.keras.models.load_model. Always return every constructor argument in the config dict, and pass the class through the custom_objects argument (or register it) when loading.
- Wrong initializer choice. Using zeros for weight matrices (as opposed to biases) can prevent the network from breaking symmetry during training. Use glorot_uniform or he_normal for weight matrices, and zeros for biases.
- Creating variables in __init__ instead of build. Variables created in __init__ do not have access to the input shape, so you have to hardcode dimensions. Use build for shape-dependent parameters.
- Forgetting to recompile after changing trainable flags. Changes to layer.trainable only take effect in fit after you call model.compile again. Keras fixes the behavior of the model, including which weights are trainable, at compile time, so without recompiling, training keeps using the old frozen/unfrozen settings.
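The initializer pitfall can be checked numerically. A minimal sketch, assuming TensorFlow 2.x: the hypothetical helper hidden_kernel_gradient builds a two-layer network with a given kernel initializer and returns the gradient of a mean-squared-error loss with respect to the hidden layer's kernel, on random toy data.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
x = tf.random.normal((32, 3))
y = tf.random.normal((32, 1))


def hidden_kernel_gradient(kernel_initializer):
    """Returns dLoss/dKernel for a 4-unit hidden layer under the given initializer."""
    hidden = tf.keras.layers.Dense(4, activation="tanh", kernel_initializer=kernel_initializer)
    head = tf.keras.layers.Dense(1, kernel_initializer=kernel_initializer)
    head(hidden(x))  # first call runs build(), so the kernels exist before taping
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(head(hidden(x)) - y))
    return tape.gradient(loss, hidden.kernel).numpy()


# zeros: hidden activations are tanh(0) = 0 and the head weights are 0,
# so no gradient reaches the hidden kernel and every unit stays identical.
g_zeros = hidden_kernel_gradient("zeros")
# glorot_uniform: random starting weights give each hidden unit its own gradient.
g_glorot = hidden_kernel_gradient("glorot_uniform")

print(np.abs(g_zeros).max())                   # 0.0
print(np.allclose(g_glorot, g_glorot[:, :1]))  # False: the columns differ
```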
Summary
Keras provides several ways to create trainable parameters. The recommended approach is to subclass tf.keras.layers.Layer and use self.add_weight inside the build method, which gives you automatic shape inference, serialization support, and integration with the Keras training loop. For learnable scalars and vectors, the same pattern works with smaller shapes and constant initializers. Use tf.Variable directly only when you need raw TensorFlow compatibility. Always implement get_config for custom layers and choose appropriate initializers based on the role of each parameter.

