Keras Custom Loss Function: Accessing the Current Input Pattern

Introduction

In Keras, a loss passed to model.compile() is intentionally narrow: it receives only the expected output and the model prediction. That design works for ordinary supervised learning, but it surprises people when the loss also depends on the current input batch. The correct solution is usually to move the extra logic into the model or the training step rather than forcing the standard loss signature to do a job it was not designed for.

Why the Standard Loss Signature Is Limited

A plain custom loss in Keras looks like this:

python
from tensorflow import keras
import tensorflow as tf


def mse_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


model = keras.Sequential(
    [
        keras.layers.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ]
)

model.compile(optimizer="adam", loss=mse_loss)

That function receives y_true and y_pred. It does not receive x, the current batch input. This is not a bug in your code. It is the contract of the API.
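To make that contract concrete, here is a minimal sketch that calls the same mse_loss directly on two small tensors. The values are made up for illustration. Nothing in the call carries the batch input, and the function returns one loss value per sample because it reduces over the last axis only.

```python
import tensorflow as tf


def mse_loss(y_true, y_pred):
    # Same function as above: reduce over the last axis, keep the batch axis.
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


# Illustrative values only: a batch of two samples with one output each.
y_true = tf.constant([[1.0], [2.0]])
y_pred = tf.constant([[1.5], [0.0]])

# The loss sees exactly two arguments; there is no slot for the inputs x.
per_sample = mse_loss(y_true, y_pred)
print(per_sample.numpy())  # one value per sample: 0.25 and 4.0
```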

If your objective is something like "make predictions close to the target, but penalize them more when a specific input feature is large," then loss= alone is the wrong layer of abstraction. You need a place in the training flow where the model can see both the inputs and the predictions at the same time.

Use add_loss() When the Penalty Belongs to the Model

Keras supports extra loss terms through add_loss(). This works well when the penalty is part of the model definition itself.

python
from tensorflow import keras
import tensorflow as tf
import numpy as np


class InputAwarePenalty(keras.layers.Layer):
    def call(self, inputs):
        original_x, predictions = inputs

        # Penalize large predictions when the first input feature is large.
        feature_weight = tf.abs(original_x[:, :1])
        penalty = tf.reduce_mean(feature_weight * tf.square(predictions))
        self.add_loss(0.05 * penalty)
        return predictions


inputs = keras.Input(shape=(3,))
x = keras.layers.Dense(16, activation="relu")(inputs)
raw_output = keras.layers.Dense(1)(x)
outputs = InputAwarePenalty()([inputs, raw_output])

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

x_train = np.array([[1.0, 0.2, 0.3], [0.0, 0.1, 0.2], [2.0, 0.5, 0.7]])
y_train = np.array([[1.0], [0.0], [1.5]])

model.fit(x_train, y_train, epochs=3, verbose=0)

This pattern is clean because the model itself declares the extra penalty. During training, Keras automatically adds anything from model.losses to the main loss.

Use this approach when the penalty is structural, similar to regularization. If the rule belongs to the model definition, add_loss() is usually the most natural choice.
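As a rough check of what the penalty in InputAwarePenalty computes, the sketch below repeats the same arithmetic outside the layer on hand-picked values, which are assumptions chosen to make the numbers easy to follow. It also shows why the slice is written as [:, :1] rather than [:, 0]. The first keeps a (batch, 1) shape that lines up with the predictions, while the second would silently broadcast to a (batch, batch) matrix.

```python
import tensorflow as tf

# Illustrative batch of two samples; the first feature is 2.0 and 0.0.
original_x = tf.constant([[2.0, 0.5, 0.7], [0.0, 0.1, 0.2]])
predictions = tf.constant([[1.0], [3.0]])

# Same arithmetic as inside the layer's call().
feature_weight = tf.abs(original_x[:, :1])          # shape (2, 1): [[2.0], [0.0]]
weighted = feature_weight * tf.square(predictions)  # shape (2, 1): [[2.0], [0.0]]
penalty = 0.05 * tf.reduce_mean(weighted)           # 0.05 * 1.0

# Pitfall: dropping the column axis changes the result through broadcasting.
# (2,) * (2, 1) broadcasts to (2, 2), mixing samples with each other.
mismatched = tf.abs(original_x[:, 0]) * tf.square(predictions)

print(float(penalty), mismatched.shape)
```

The second sample contributes nothing to the penalty even though its prediction is the largest, because its first feature is zero.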

Use a Custom train_step() for Full Control

Sometimes the loss depends on the inputs, targets, predictions, and maybe additional bookkeeping. In that case, subclassing keras.Model and overriding train_step() is clearer than trying to hide the logic elsewhere.

python
from tensorflow import keras
import tensorflow as tf
import numpy as np


class InputAwareModel(keras.Model):
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            base_loss = tf.reduce_mean(tf.square(y - y_pred))
            input_penalty = tf.reduce_mean(tf.abs(x[:, :1]) * tf.square(y_pred))
            loss = base_loss + 0.05 * input_penalty

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"loss": loss}


# Build the network with the functional API: a keras.Model subclass takes
# inputs and outputs, not a list of layers (that is the Sequential signature).
inputs = keras.Input(shape=(3,))
hidden = keras.layers.Dense(16, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(hidden)
model = InputAwareModel(inputs, outputs)

# No loss= argument: the loss is computed inside train_step().
model.compile(optimizer="adam")

x_train = np.random.random((32, 3)).astype("float32")
y_train = np.random.random((32, 1)).astype("float32")
model.fit(x_train, y_train, epochs=2, verbose=0)

This is the right tool when the training rule is procedural rather than architectural. It keeps the logic explicit and avoids strange side effects from trying to smuggle data through the wrong API surface.

Should You Pack Inputs into y_true?

You can technically pack extra values into the target tensor and unpack them inside the loss function. For example, some codebases concatenate a real label with one or more input-derived values. That works, but it usually makes the training pipeline harder to read and debug.

It also creates awkward coupling between your data loader and your loss function. Someone reading the model later may think the target shape changed for a modeling reason when it actually changed only to bypass the loss signature. That is why add_loss() or train_step() is usually the better long-term design.
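For reference, the packing workaround might look like the sketch below. The name packed_loss and the column layout (label first, input-derived value second) are assumptions made for illustration, not a convention from Keras. The loss is evaluated directly on stand-in predictions so the coupling is visible without training a model.

```python
import numpy as np
import tensorflow as tf


def packed_loss(y_packed, y_pred):
    """Hypothetical loss that expects [label, first input feature] per row."""
    y_true = y_packed[:, :1]    # column 0: the real label
    feature = y_packed[:, 1:2]  # column 1: input-derived value carried along
    base_loss = tf.reduce_mean(tf.square(y_true - y_pred))
    input_penalty = tf.reduce_mean(tf.abs(feature) * tf.square(y_pred))
    return base_loss + 0.05 * input_penalty


x_train = np.array(
    [[1.0, 0.2, 0.3], [0.0, 0.1, 0.2], [2.0, 0.5, 0.7]], dtype="float32"
)
y_train = np.array([[1.0], [0.0], [1.5]], dtype="float32")

# The data pipeline now has to know about the loss:
# the targets grow from shape (3, 1) to shape (3, 2).
y_packed = np.concatenate([y_train, x_train[:, :1]], axis=1)

# Stand-in predictions so the loss can be evaluated without a model.
y_pred = tf.ones((3, 1))
loss_value = packed_loss(tf.constant(y_packed), y_pred)
print(y_packed.shape, float(loss_value))
```

In a real run, the function would presumably be passed as loss=packed_loss with y_packed as the target in fit(). The target then no longer has the same shape as the model output, which is exactly the kind of surprise described above.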

How to Choose Between the Two Approaches

Use add_loss() when the extra term behaves like a model-level penalty and can be computed naturally during call(). Use train_step() when the training algorithm itself needs custom control.

That distinction matters because it keeps your code easy to reason about:

  • add_loss() keeps the built-in fit() flow mostly intact.
  • train_step() gives you full control when the default training loop is too restrictive.
  • A plain loss= function should be reserved for logic that only depends on y_true and y_pred.

Common Pitfalls

One common mistake is assuming a custom loss can directly access the current input batch because the model obviously has an input. Keras deliberately separates those concepts, so the loss callable from compile() cannot see the batch input unless you redesign the training flow.

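Another mistake is closing over symbolic tensors from model construction, such as model.input, and expecting them to behave like ordinary batch values during training. A symbolic tensor only describes a shape and dtype; it never holds the values of the batch being processed, so using it inside a loss typically ends in confusing shape errors or graph-mode issues.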

Developers also sometimes overload y_true with extra tensors to work around the problem. That can run, but it makes the dataset format harder to understand and increases maintenance cost.

Finally, if you override train_step(), remember that you now own more of the training loop. Make sure you compute gradients, apply them, and return metrics in a consistent way.
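One way to keep the reported numbers consistent is sketched below, under a few assumptions. The class name TrackedInputAwareModel and the helper _input_aware_loss are hypothetical. The sketch tracks the loss with a keras.metrics.Mean, exposes it through the metrics property so Keras can reset it between epochs, and overrides test_step() so evaluate() uses the same loss as training.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class TrackedInputAwareModel(keras.Model):
    """Hypothetical variant that reports a running mean of the custom loss."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")

    def _input_aware_loss(self, x, y, y_pred):
        # Same rule as in the train_step() example: MSE plus an input-weighted penalty.
        base_loss = tf.reduce_mean(tf.square(y - y_pred))
        input_penalty = tf.reduce_mean(tf.abs(x[:, :1]) * tf.square(y_pred))
        return base_loss + 0.05 * input_penalty

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self._input_aware_loss(x, y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Evaluation reuses the training loss but does not update weights.
        x, y = data
        y_pred = self(x, training=False)
        self.loss_tracker.update_state(self._input_aware_loss(x, y, y_pred))
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch.
        return [self.loss_tracker]


inputs = keras.Input(shape=(3,))
hidden = keras.layers.Dense(16, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(hidden)
model = TrackedInputAwareModel(inputs, outputs)
model.compile(optimizer="adam")

x_train = np.random.random((32, 3)).astype("float32")
y_train = np.random.random((32, 1)).astype("float32")
history = model.fit(x_train, y_train, epochs=2, verbose=0)
eval_logs = model.evaluate(x_train, y_train, verbose=0, return_dict=True)
```

Without the tracker, the value returned from train_step() is the loss of the last batch only, so the number shown per epoch can jump around more than the actual average.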

Summary

  • A normal Keras loss function only receives y_true and y_pred.
  • If the loss depends on the input batch, use add_loss() or a custom train_step().
  • add_loss() is best for model-level penalties that fit naturally into call().
  • train_step() is best when the training procedure itself needs custom control.
  • Packing extra input data into y_true can work, but it is usually a maintenance-heavy workaround.
