Making a Custom Model Compatible with HuggingFace Trainer

Tags: transformers, pytorch, training, TIL
Plug any custom PyTorch model into HuggingFace Trainer: forward must return a loss.
Author: Mahamadi NIKIEMA

Published: April 8, 2026

The contract HF Trainer expects

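Trainer needs one thing from your model: when labels are provided, forward() must return the loss where Trainer can find it, either under the "loss" key of a dict-like output (a ModelOutput qualifies, and also exposes it as .loss) or as the first element of a tuple. The rest of the training loop (optimizer, scheduler, gradient accumulation) is handled for you.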
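To make the contract concrete, here is a simplified sketch of how the loss is read from the model output. It paraphrases the default behavior rather than reproducing the library's code, and extract_loss is a hypothetical name used only for illustration.

```python
def extract_loss(outputs):
    """Simplified stand-in for how Trainer locates the loss (illustrative only)."""
    if isinstance(outputs, dict):
        # Dict-like outputs (ModelOutput subclasses included) are read by key.
        if "loss" not in outputs:
            raise ValueError("The model did not return a loss; were labels passed?")
        return outputs["loss"]
    # Anything else is treated as a sequence whose first element is the loss.
    return outputs[0]


print(extract_loss({"loss": 0.25, "predictions": [0.1, 0.9]}))
print(extract_loss((0.5, [0.1, 0.9])))
```

An output that has no "loss" key and no loss in first position fails at this step, which is the usual symptom of a forward() that ignores labels.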

1. Wrap model + loss in one nn.Module

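Use ModelOutput (a base class from transformers whose subclasses are declared as dataclasses) as the return type: it supports attribute, key, and index access, so Trainer can read the loss from it directly.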

Let us define two classes:

  • Output, with two fields: loss, which Trainer requires, and predictions, which stores the model's predictions
  • SimpleModel, a simple linear regression over 10 input features
import torch
import torch.nn as nn
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional


@dataclass
class Output(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    predictions: Optional[torch.FloatTensor] = None


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Linear regression: 10 input features -> 1 predicted value
        self.simple_layer = nn.Linear(in_features=10, out_features=1)

    def forward(self, x: torch.Tensor, labels=None):
        out = self.simple_layer(x)
        # Only compute the loss when labels are passed
        loss = None
        if labels is not None:
            loss = torch.nn.functional.mse_loss(out, labels.float())
        return Output(loss=loss, predictions=out)
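The small sketch below shows the access patterns a ModelOutput subclass offers, and what changes when the loss is left as None. DemoOutput is a hypothetical name that mirrors the Output class above so the snippet can run on its own.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class DemoOutput(ModelOutput):  # hypothetical name, same fields as Output
    loss: Optional[torch.FloatTensor] = None
    predictions: Optional[torch.FloatTensor] = None


with_loss = DemoOutput(loss=torch.tensor(0.5), predictions=torch.zeros(2, 1))
# The loss is reachable by attribute, by key, and by position.
print(with_loss.loss, with_loss["loss"], with_loss[0])

no_loss = DemoOutput(predictions=torch.zeros(2, 1))
# Fields left as None are dropped from the keys,
# so position 0 now holds the predictions instead of the loss.
print(list(no_loss.keys()))
```

This is why labels must reach forward() during training: without them the output carries no loss for Trainer to read.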

2. Format the dataset with matching column names

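The dataset column names (the keys of each sample dict) must exactly match the argument names of forward(), because each batch is passed to the model as keyword arguments.

In our case, forward() expects x and labels, so every sample is a dict with those two keys. We use torch.rand to generate uniformly distributed toy data: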


from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.rand(n, 10)
        self.labels = torch.rand(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {"x": self.x[idx], "labels": self.labels[idx]}
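One way to catch a naming mismatch before training starts is to compare a sample's keys with the parameters of forward(). The sketch below uses only the standard library; unexpected_keys is a hypothetical helper, and the plain forward function stands in for SimpleModel.forward.

```python
import inspect


def unexpected_keys(forward_fn, sample):
    """Return the sample keys that forward_fn does not accept as parameters."""
    accepted = set(inspect.signature(forward_fn).parameters)
    return sorted(set(sample) - accepted)


# Stand-in with the same signature as SimpleModel.forward (minus self).
def forward(x, labels=None):
    return None


print(unexpected_keys(forward, {"x": [0.0] * 10, "labels": [1.0]}))         # []
print(unexpected_keys(forward, {"features": [0.0] * 10, "labels": [1.0]}))  # ['features']
```

With the real objects, the same check would be unexpected_keys(simple_model.forward, dataset[0]); an empty list means every key maps to an argument.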

3. Putting it all together

We can pass the model and the dataset directly to the trainer. For this simple case no custom collator is needed, because the default one stacks the per-sample tensors into batches.

from transformers import Trainer, TrainingArguments

# A map-style dataset, not a DataLoader; Trainer builds the DataLoader itself
train_dataset = SimpleDataset()
simple_model = SimpleModel()

trainer = Trainer(
    model=simple_model,
    args=TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        num_train_epochs=2,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)

trainer.train()
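To see what reaches forward() when no collator is supplied, the sketch below calls default_data_collator by hand on two samples shaped like the ones SimpleDataset returns. It assumes, as is the case when no tokenizer or processor is given to Trainer, that this is the collator used by default.

```python
import torch
from transformers import default_data_collator

# Two samples shaped like the items returned by SimpleDataset.__getitem__.
samples = [{"x": torch.rand(10), "labels": torch.rand(1)} for _ in range(2)]

# Tensors that share a key are stacked along a new leading batch dimension.
batch = default_data_collator(samples)
print({name: tuple(tensor.shape) for name, tensor in batch.items()})
```

The resulting dict has the same keys as each sample, which is what lets the batch be unpacked into forward(x=..., labels=...).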

Key takeaways

What                                     | Why
Return a ModelOutput with a loss         | Trainer reads the loss from the model output and calls backward on it
Column names (dict keys) = forward args  | Trainer passes each batch to forward() as keyword arguments