Keras Custom Loss Function: Accessing the Current Input Pattern
Introduction
In Keras, a loss passed to model.compile() is intentionally narrow: it receives only the expected output and the model prediction. That design works for ordinary supervised learning, but it surprises people when the loss also depends on the current input batch. The correct solution is usually to move the extra logic into the model or the training step rather than forcing the standard loss signature to do a job it was not designed for.
Why the Standard Loss Signature Is Limited
A plain custom loss in Keras is just a function passed as loss= to model.compile().
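Here is a minimal sketch, assuming TensorFlow 2.x with its bundled Keras. The function name custom_mse and the random toy data are illustrative only.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def custom_mse(y_true, y_pred):
    # Keras passes only the batch targets and the batch predictions.
    # There is no argument through which the batch inputs x could arrive.
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


# Hypothetical toy data: 64 samples, 3 input features, 1 regression target.
x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss=custom_mse)
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
```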
That function receives y_true and y_pred. It does not receive x, the current batch input. This is not a bug in your code. It is the contract of the API.
If your objective is something like "make predictions close to the target, but penalize them more when a specific input feature is large," then loss= alone is the wrong layer of abstraction. You need a place in the training flow where the model can see both the inputs and the predictions at the same time.
Use add_loss() When the Penalty Belongs to the Model
Keras supports extra loss terms through add_loss(). This works well when the penalty is part of the model definition itself.
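The following is a minimal sketch, assuming the TensorFlow backend. The layer name InputScaledPenalty and its rate parameter are hypothetical. Because call() never receives y_true, the penalty can only combine inputs with predictions. Here it scales the squared prediction by the magnitude of the first input feature. Registering the term with self.add_loss() inside a layer's call() should work in both tf.keras 2.x and Keras 3.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class InputScaledPenalty(keras.layers.Layer):
    """Hypothetical layer: registers rate * mean(|x[:, 0]| * y_pred**2) via add_loss()."""

    def __init__(self, rate=0.01, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate  # illustrative penalty weight

    def call(self, inputs):
        x, y_pred = inputs
        feature = tf.abs(x[:, :1])  # magnitude of the first input feature, shape (batch, 1)
        self.add_loss(self.rate * tf.reduce_mean(feature * tf.square(y_pred)))
        return y_pred  # predictions pass through unchanged


inputs = keras.Input(shape=(3,))
hidden = keras.layers.Dense(8, activation="relu")(inputs)
raw_pred = keras.layers.Dense(1)(hidden)
outputs = InputScaledPenalty(rate=0.01)([inputs, raw_pred])
model = keras.Model(inputs, outputs)

# The compiled loss still sees only (y_true, y_pred); fit() adds the penalty on top.
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
```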
This pattern is clean because the model itself declares the extra penalty. During training, Keras automatically adds anything from model.losses to the main loss.
Use this approach when the penalty is structural, similar to regularization. If the rule belongs to the model definition, add_loss() is usually the most natural choice.
Use a Custom train_step() for Full Control
Sometimes the loss depends on the inputs, targets, predictions, and maybe additional bookkeeping. In that case, subclassing keras.Model and overriding train_step() is clearer than trying to hide the logic elsewhere.
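The sketch below assumes the TensorFlow backend. It implements the kind of objective mentioned earlier by weighting each squared error by 1 + alpha * |x[:, 0]|, so that samples with a large first input feature are penalized more. The class name InputAwareModel, the alpha parameter, and this particular weighting are illustrative choices, not a prescribed formula.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class InputAwareModel(keras.Model):
    """Sketch: minimizes mean((1 + alpha * |x[:, 0]|) * (y - y_pred)**2)."""

    def __init__(self, *args, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha  # illustrative weight on the input-dependent term
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch.
        return [self.loss_tracker]

    def train_step(self, data):
        x, y = data  # assumes (inputs, targets) batches with no sample weights
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Inputs, targets, and predictions are all in scope here.
            weight = 1.0 + self.alpha * tf.abs(x[:, :1])
            loss = tf.reduce_mean(weight * tf.square(y - y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(1)(inputs)
model = InputAwareModel(inputs, outputs, alpha=0.5)
model.compile(optimizer="adam")  # no loss=: train_step() computes it

x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
```

Because the loss is computed inside train_step(), compile() is called without loss=. If validation should use the same objective, you would likely need to override test_step() as well.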
This is the right tool when the training rule is procedural rather than architectural. It keeps the logic explicit and avoids strange side effects from trying to smuggle data through the wrong API surface.
Should You Pack Inputs into y_true?
You can technically pack extra values into the target tensor and unpack them inside the loss function. For example, some codebases concatenate a real label with one or more input-derived values. That works, but it usually makes the training pipeline harder to read and debug.
It also creates awkward coupling between your data loader and your loss function. Someone reading the model later may think the target shape changed for a modeling reason when it actually changed only to bypass the loss signature. That is why add_loss() or train_step() is usually the better long-term design.
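For comparison, the packing workaround might look like the following minimal sketch, again assuming TensorFlow 2.x with its bundled Keras. The two-column target layout, the function name packed_loss, and the 0.5 weight are all hypothetical. The sketch shows how the loss comes to depend on an undocumented convention in the data pipeline.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def packed_loss(y_true, y_pred):
    # Hypothetical layout: column 0 is the real label and column 1 is
    # a copy of input feature 0, smuggled in through the target tensor.
    label = y_true[:, :1]
    feature = y_true[:, 1:2]
    weight = 1.0 + 0.5 * tf.abs(feature)
    return tf.reduce_mean(weight * tf.square(label - y_pred), axis=-1)


x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")
# The data pipeline now has to know about the loss's private layout.
y_packed = np.concatenate([y, x[:, :1]], axis=1)  # shape (64, 2)

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss=packed_loss)
history = model.fit(x, y_packed, epochs=2, batch_size=16, verbose=0)
```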
How to Choose Between the Two Approaches
Use add_loss() when the extra term behaves like a model-level penalty and can be computed naturally during call(). Use train_step() when the training algorithm itself needs custom control.
That distinction matters because it keeps your code easy to reason about:
- add_loss() keeps the built-in fit() flow mostly intact.
- train_step() gives you full control when the default training loop is too restrictive.
- A plain loss= function should be reserved for logic that only depends on y_true and y_pred.
Common Pitfalls
One common mistake is assuming a custom loss can directly access the current input batch because the model obviously has an input. Keras deliberately separates those concepts, so the loss callable from compile() cannot see the batch input unless you redesign the training flow.
Another mistake is closing over symbolic tensors from model construction and expecting them to behave like ordinary batch values during training. That can lead to confusing shape errors or graph-mode issues.
Developers also sometimes overload y_true with extra tensors to work around the problem. That can run, but it makes the dataset format harder to understand and increases maintenance cost.
Finally, if you override train_step(), remember that you now own more of the training loop. Make sure you compute gradients, apply them, and return metrics in a consistent way.
Summary
- A normal Keras loss function only receives y_true and y_pred.
- If the loss depends on the input batch, use add_loss() or a custom train_step().
- add_loss() is best for model-level penalties that fit naturally into call().
- train_step() is best when the training procedure itself needs custom control.
- Packing extra input data into y_true can work, but it is usually a maintenance-heavy workaround.

