Making a Custom Model Compatible with HuggingFace Trainer
transformers
pytorch
training
TIL
Plug any custom PyTorch model into HuggingFace Trainer: forward must return a loss.
The contract HF Trainer expects
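Trainer needs one thing from your model: when labels are provided, forward() must return the loss, either in a dict-like object with a loss entry (a ModelOutput also exposes it as the .loss attribute) or as the first element of a tuple. Everything else (optimizer, scheduler, gradient accumulation) is handled for you.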
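To make the contract concrete, here is a simplified sketch of how the loss can be read from a model's output, modeled on Trainer's default behavior when no custom loss is configured. The helper name extract_loss is hypothetical and not part of transformers; the real logic lives inside Trainer and covers more cases.

```python
def extract_loss(outputs):
    """Hypothetical helper: a simplified view of how the loss is read from model outputs."""
    if isinstance(outputs, dict):
        # Dict-like outputs (ModelOutput subclasses included) are read by key.
        if "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss; keys: " + ", ".join(outputs.keys())
            )
        return outputs["loss"]
    # Tuple outputs are expected to carry the loss in the first position.
    return outputs[0]


print(extract_loss({"loss": 0.25, "predictions": [1.0, 2.0]}))  # 0.25
print(extract_loss((0.5, [1.0, 2.0])))                          # 0.5
```

A dict without a loss entry raises an error here, which is the situation you end up in if forward() skips the loss while labels are present.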
1. Wrap model + loss in one nn.Module
Use a subclass of ModelOutput (the base class transformers uses for its own model outputs) as the return type: it behaves like both a dataclass and a dictionary, so Trainer can read the loss from it.
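As a quick illustration of why this return type is convenient, the sketch below shows the three ways a field can be read from a ModelOutput subclass. DemoOutput is a hypothetical name used only for this example.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class DemoOutput(ModelOutput):  # hypothetical name, for illustration only
    loss: Optional[torch.FloatTensor] = None
    predictions: Optional[torch.FloatTensor] = None


out = DemoOutput(loss=torch.tensor(0.5), predictions=torch.zeros(2, 1))

# The same value is reachable by attribute, by key, and by position.
print(float(out.loss), float(out["loss"]), float(out[0]))  # 0.5 0.5 0.5

# Fields left as None are dropped from the keys and from the tuple view,
# so position 0 is the loss only when a loss was actually computed.
no_loss = DemoOutput(predictions=torch.zeros(2, 1))
print(list(no_loss.keys()))     # ['predictions']
print(len(no_loss.to_tuple()))  # 1
```

This is also why the loss should be the first declared field: when it is present, it is the first element of the tuple view.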
Let us define two classes:
- Output, with two attributes: loss, required by the Trainer, and predictions, to store predictions
- SimpleModel, a simple linear regression with 10 features
```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers.modeling_outputs import ModelOutput


@dataclass
class Output(ModelOutput):
    # loss comes first so it is also the first element of the tuple view
    loss: Optional[torch.FloatTensor] = None
    predictions: Optional[torch.FloatTensor] = None


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.simple_layer = nn.Linear(in_features=10, out_features=1)

    def forward(self, x: torch.Tensor, labels=None):
        out = self.simple_layer(x)
        loss = None
        # Compute the loss only when labels are passed in
        if labels is not None:
            loss = torch.nn.functional.mse_loss(out, labels.float())
        return Output(loss=loss, predictions=out)
```

2. Format the dataset with matching column names
The dataset column names must exactly match the argument names of forward().
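The reason is that each batch is a dictionary that gets passed to the model as keyword arguments. Below is a minimal sketch of a pre-flight check using only the standard library; check_batch_keys and the stand-in forward function are hypothetical names for illustration, not part of transformers.

```python
import inspect


def forward(x, labels=None):
    """Stand-in with the same signature as the model's forward()."""
    return {"x": x, "labels": labels}


def check_batch_keys(fn, batch):
    """Return the batch keys that fn would not accept as keyword arguments."""
    params = inspect.signature(fn).parameters
    # A **kwargs parameter accepts any key, so nothing can be unexpected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [key for key in batch if key not in params]


good_batch = {"x": [0.1, 0.2], "labels": [1.0]}
bad_batch = {"features": [0.1, 0.2], "labels": [1.0]}

print(check_batch_keys(forward, good_batch))  # []
print(check_batch_keys(forward, bad_batch))   # ['features']

forward(**good_batch)  # works: every key maps to a parameter
# forward(**bad_batch) would raise a TypeError for the unexpected keyword 'features'
```

The same check can be pointed at a real model by passing model.forward instead of the stand-in function.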
In our case, forward() expects x and labels, and we use torch.rand to generate uniformly distributed data as follows:
```python
import torch
from torch.utils.data import Dataset


class SimpleDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.rand(n, 10)
        self.labels = torch.rand(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Keys must match the argument names of SimpleModel.forward()
        return {"x": self.x[idx], "labels": self.labels[idx]}
```

3. Putting it all together
For this simple case we can pass the model and the dataset directly to the Trainer. No custom collator is needed, because the default one stacks the per-example tensors into a batch.
```python
from transformers import Trainer, TrainingArguments

train_dataset = SimpleDataset()
simple_model = SimpleModel()

trainer = Trainer(
    model=simple_model,
    args=TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        num_train_epochs=2,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)
trainer.train()
```

Key takeaways
| What | Why |
|---|---|
| Return ModelOutput with .loss | Trainer reads the loss from the model output and calls backward on it |
| Column names = forward arg names | Trainer unpacks the batch into forward |
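
Both takeaways can be checked by hand before a training run. The sketch below is a minimal, Trainer-free sanity check: it builds a batch manually, unpacks it into forward, and backpropagates the returned loss. TinyRegressor is a hypothetical stand-in for SimpleModel, and it returns a plain dict rather than a ModelOutput only to keep the example free of the transformers dependency.

```python
import torch
import torch.nn as nn


class TinyRegressor(nn.Module):  # hypothetical stand-in for SimpleModel
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x, labels=None):
        out = self.linear(x)
        loss = None
        if labels is not None:
            loss = nn.functional.mse_loss(out, labels.float())
        # A plain dict is used here instead of a ModelOutput subclass.
        return {"loss": loss, "predictions": out}


# Build a batch by hand: stack individual examples along a new first dimension.
examples = [{"x": torch.rand(10), "labels": torch.rand(1)} for _ in range(4)]
batch = {key: torch.stack([ex[key] for ex in examples]) for key in examples[0]}

model = TinyRegressor()
outputs = model(**batch)    # the batch keys become keyword arguments
outputs["loss"].backward()  # gradients flow from the returned loss

print(outputs["predictions"].shape)          # torch.Size([4, 1])
print(outputs["loss"].dim())                 # 0, i.e. a scalar
print(model.linear.weight.grad is not None)  # True
```

If this runs cleanly with your own model and one batch from your dataset, the two requirements in the table are met.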