Adding Loss Functions¶
Loss functions in Molfun return a dictionary of named scalar tensors, enabling callers to log individual loss terms without knowing the internals. The LOSS_REGISTRY provides name-based lookup so training strategies and heads can resolve losses at runtime.
The LossFunction interface¶
All losses inherit from LossFunction in molfun/losses/base.py:
class LossFunction(ABC, nn.Module):
    """
    Abstract base for all Molfun loss functions.

    Signature: loss_fn(preds, targets=None, batch=None) -> dict[str, Tensor]

    Subclasses implement ``forward``. Because ``forward`` is abstract, this
    base class cannot be instantiated directly.
    """

    @abstractmethod
    def forward(
        self,
        preds: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        batch: Optional[dict] = None,
    ) -> dict[str, torch.Tensor]:
        """Compute loss and return a dict of named scalar tensors.

        Args:
            preds: Model predictions.
            targets: Optional ground truth; may be omitted by losses that
                read their ground truth from ``batch`` instead.
            batch: Optional feature dict forwarded to the loss for any
                additional context it needs.

        Returns:
            Mapping from loss-term name to a scalar tensor.
        """
Key design points:
- Returns a dict, not a single tensor. This allows composite losses to expose individual terms for logging.
- `targets` is optional because some losses extract ground truth from the `batch` dict directly (e.g., structure losses that need atom coordinates from the feature dict).
- `batch` is the raw feature dict from the DataLoader, forwarded for losses that need additional context.
The LossRegistry¶
The LossRegistry is separate from ModuleRegistry because loss functions have a simpler API:
LOSS_REGISTRY.register("name")  # decorator; apply as @LOSS_REGISTRY.register("name") above the loss class
LOSS_REGISTRY["name"]           # get class (raises KeyError if missing)
LOSS_REGISTRY["name"]()         # instantiate with defaults
"name" in LOSS_REGISTRY         # membership test
list(LOSS_REGISTRY)             # all registered names
Example: Contact Map Loss¶
A contact map loss penalizes predicted structures where residue pairs that should be in contact (based on ground-truth distance maps) are too far apart, and vice versa.
Step 1: Create the loss file¶
Create molfun/losses/contact.py:
"""
Contact map loss: penalizes incorrect inter-residue distance predictions.
Useful as an auxiliary loss alongside FAPE/coordinate losses, providing
a complementary signal about the global topology of the predicted structure.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from molfun.losses.base import LossFunction, LOSS_REGISTRY
@LOSS_REGISTRY.register("contact")
class ContactMapLoss(LossFunction):
    """
    Binary cross-entropy loss on predicted vs. ground-truth contact maps.

    A contact is defined as two C-alpha atoms within ``threshold`` angstroms.
    Ground-truth contacts are derived from the target positions; the
    prediction is scored through the soft logit ``threshold - distance`` so
    that gradients flow back to the predicted coordinates.

    The loss is averaged over *valid* residue pairs only, i.e. pairs that are
    at least ``min_seq_sep`` apart in sequence and (when a mask is supplied)
    whose residues are both unmasked.

    Args:
        threshold: Distance threshold in angstroms for defining a contact.
        weight: Scalar weight for the loss term.
        min_seq_sep: Minimum sequence separation to consider (ignore trivial
            contacts between adjacent residues).
    """

    def __init__(
        self,
        threshold: float = 8.0,
        weight: float = 1.0,
        min_seq_sep: int = 6,
    ):
        super().__init__()
        self.threshold = threshold
        self.weight = weight
        self.min_seq_sep = min_seq_sep

    @staticmethod
    def _pairwise_distances(positions: torch.Tensor) -> torch.Tensor:
        """Return [B, L, L] Euclidean distances for [B, L, 3] coordinates."""
        diff = positions.unsqueeze(2) - positions.unsqueeze(1)  # [B, L, L, 3]
        # The epsilon keeps the sqrt gradient finite on the zero diagonal.
        return torch.sqrt((diff ** 2).sum(-1) + 1e-8)

    def _valid_pair_mask(
        self, positions: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Build the float mask of residue pairs that take part in the loss.

        Args:
            positions: [B, L, 3] coordinates (only shape and device are used).
            mask: Optional [B, L] residue mask.

        Returns:
            [1, L, L] mask when ``mask`` is None, otherwise [B, L, L];
            1 = pair is counted, 0 = pair is ignored.
        """
        num_res = positions.shape[1]
        seq_idx = torch.arange(num_res, device=positions.device)
        sep = (seq_idx.unsqueeze(1) - seq_idx.unsqueeze(0)).abs()
        valid = (sep >= self.min_seq_sep).float().unsqueeze(0)
        if mask is not None:
            # Cast so that boolean / integer masks combine cleanly.
            residue_mask = mask.float()
            valid = valid * (residue_mask.unsqueeze(2) * residue_mask.unsqueeze(1))
        return valid

    def _compute_contact_map(
        self, positions: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute binary contact map from C-alpha positions.

        Args:
            positions: [B, L, 3] C-alpha coordinates.
            mask: [B, L] residue mask.

        Returns:
            [B, L, L] binary contact map (1 = contact, 0 = no contact).
            Pairs closer than ``min_seq_sep`` in sequence, and pairs touching
            a masked residue, are always 0.
        """
        dist = self._pairwise_distances(positions)
        contacts = (dist < self.threshold).float()
        return contacts * self._valid_pair_mask(positions, mask)

    def forward(
        self,
        preds: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        batch: Optional[dict] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            preds: Predicted positions [B, L, 3].
            targets: Ground-truth positions [B, L, 3]. If None, extracted
                from ``batch["gt_positions"]``.
            batch: Feature dict from the DataLoader. An optional ``"mask"``
                entry ([B, L]) restricts the loss to unmasked residues.

        Returns:
            Dict with ``"contact_loss"`` key.

        Raises:
            ValueError: If no ground truth is available from either
                ``targets`` or ``batch["gt_positions"]``.
        """
        if targets is None and batch is not None:
            # .get() so a batch without ground truth hits the ValueError
            # below instead of surfacing as a bare KeyError.
            targets = batch.get("gt_positions")
        if targets is None:
            raise ValueError("ContactMapLoss requires targets or batch['gt_positions']")
        mask = batch.get("mask") if batch is not None else None

        true_contacts = self._compute_contact_map(targets, mask)

        # Soft logit from predicted distances: positive = contact.
        pred_logits = self.threshold - self._pairwise_distances(preds)
        loss = F.binary_cross_entropy_with_logits(
            pred_logits, true_contacts.to(pred_logits.dtype), reduction="none"
        )

        # Average over the pairs that actually contribute. Normalising by
        # every pair (or by the residue-pair mask alone) would dilute the
        # loss with entries already zeroed by the sequence-separation mask.
        valid = self._valid_pair_mask(preds, mask).to(loss.dtype).expand_as(loss)
        loss = (loss * valid).sum() / valid.sum().clamp(min=1)

        return {"contact_loss": self.weight * loss}
Step 2: Register via __init__.py¶
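Add the import to molfun/losses/__init__.py so that the @LOSS_REGISTRY.register("contact") decorator runs when the package is imported:
from molfun.losses.contact import ContactMapLoss  # noqa: F401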
Testing¶
Create tests/losses/test_contact.py:
import pytest
import torch
from molfun.losses.base import LOSS_REGISTRY
class TestContactMapLoss:
    """Unit tests for the loss registered under the name ``"contact"``."""

    @pytest.fixture(autouse=True)
    def _seed(self):
        # Fixed seed so a failure can be reproduced with the same inputs.
        torch.manual_seed(0)

    @pytest.fixture
    def loss_fn(self):
        return LOSS_REGISTRY["contact"](threshold=8.0, min_seq_sep=3)

    def test_registry_lookup(self):
        assert "contact" in LOSS_REGISTRY

    def test_returns_dict(self, loss_fn):
        preds = torch.randn(2, 10, 3)
        targets = torch.randn(2, 10, 3)
        result = loss_fn(preds, targets)
        assert isinstance(result, dict)
        assert "contact_loss" in result

    def test_scalar_output(self, loss_fn):
        preds = torch.randn(2, 10, 3)
        targets = torch.randn(2, 10, 3)
        result = loss_fn(preds, targets)
        assert result["contact_loss"].dim() == 0  # scalar

    def test_perfect_prediction(self, loss_fn):
        """Loss should be low when prediction matches target."""
        targets = torch.randn(2, 20, 3)
        # Identical predictions
        result_same = loss_fn(targets, targets)
        # Wildly different predictions
        result_diff = loss_fn(targets * 10, targets)
        # Same should have lower loss
        assert result_same["contact_loss"] < result_diff["contact_loss"]

    def test_with_batch_dict(self, loss_fn):
        """Can extract targets from batch dict."""
        preds = torch.randn(2, 10, 3)
        batch = {"gt_positions": torch.randn(2, 10, 3)}
        result = loss_fn(preds, batch=batch)
        assert "contact_loss" in result

    def test_missing_targets_raises(self, loss_fn):
        """Without targets or a batch there is no ground truth to compare to."""
        with pytest.raises(ValueError):
            loss_fn(torch.randn(2, 10, 3))

    def test_with_mask(self, loss_fn):
        """Masked residues must not influence the loss."""
        preds = torch.randn(2, 10, 3)
        targets = torch.randn(2, 10, 3)
        mask = torch.ones(2, 10)
        mask[:, -3:] = 0
        batch = {"mask": mask}
        result = loss_fn(preds, targets, batch=batch)
        assert "contact_loss" in result
        assert torch.isfinite(result["contact_loss"])

        # Moving only the masked residues far away must leave the loss as is.
        perturbed = preds.clone()
        perturbed[:, -3:] += 100.0
        result_perturbed = loss_fn(perturbed, targets, batch=batch)
        torch.testing.assert_close(
            result_perturbed["contact_loss"], result["contact_loss"]
        )

    def test_gradient_flow(self, loss_fn):
        preds = torch.randn(2, 10, 3, requires_grad=True)
        targets = torch.randn(2, 10, 3)
        result = loss_fn(preds, targets)
        result["contact_loss"].backward()
        assert preds.grad is not None
        assert torch.isfinite(preds.grad).all()

    def test_weight_scaling(self):
        """Weight parameter scales the loss."""
        preds = torch.randn(2, 10, 3)
        targets = torch.randn(2, 10, 3)
        loss_w1 = LOSS_REGISTRY["contact"](weight=1.0)(preds, targets)
        loss_w5 = LOSS_REGISTRY["contact"](weight=5.0)(preds, targets)
        torch.testing.assert_close(
            loss_w5["contact_loss"],
            loss_w1["contact_loss"] * 5.0,
            rtol=1e-5, atol=1e-5,
        )
Integration: Using in training¶
As the primary loss¶
from molfun import MolfunStructureModel
from molfun.training import HeadOnlyFinetune
# The loss is passed by its registry name ("contact") rather than as an instance.
strategy = HeadOnlyFinetune(lr=1e-3, loss_fn="contact")
model = MolfunStructureModel("openfold")
# NOTE: train_loader / val_loader are not defined in this snippet; the caller must supply them.
strategy.fit(model, train_loader, val_loader, epochs=20)
In combination with other losses¶
Since losses return dicts, you can compose them:
from molfun.losses.base import LOSS_REGISTRY
fape_fn = LOSS_REGISTRY["fape"]()
contact_fn = LOSS_REGISTRY["contact"](weight=0.5)  # contact term scaled by 0.5
# In a custom training loop:
fape_losses = fape_fn(preds, targets, batch)
# NOTE(review): assumes ``preds`` exposes a ``positions`` attribute holding the
# predicted coordinates — confirm against the model's output type.
# ``targets`` is omitted, so the contact loss reads batch["gt_positions"].
contact_losses = contact_fn(preds.positions, batch=batch)
# Merge the per-term dicts; on a key collision the later dict's entry wins.
all_losses = {**fape_losses, **contact_losses}
total = sum(all_losses.values())  # sum of every named loss term