Loss Functions¶
Molfun provides a registry of loss functions for structure prediction and
affinity tasks. All losses implement the LossFunction abstract base class
and are registered in LOSS_REGISTRY.
Quick Start¶
```python
from molfun.losses import LOSS_REGISTRY, MSELoss, PearsonLoss

# Use the registry
loss_fn = LOSS_REGISTRY["mse"]()
result = loss_fn(preds, targets)
print(result)  # {"affinity_loss": tensor(0.42)}

# Or instantiate directly
loss_fn = PearsonLoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(0.15)}

# Register a custom loss
from molfun.losses import LossFunction

@LOSS_REGISTRY.register("tmscore")
class TMScoreLoss(LossFunction):
    def forward(self, preds, targets=None, batch=None):
        score = compute_tm_score(preds, targets)
        return {"tmscore_loss": 1.0 - score}
```
LOSS_REGISTRY¶
LossFunction (ABC)¶
Bases: ABC, Module
Abstract base for all Molfun loss functions.
A LossFunction is a callable nn.Module with a unified signature:
```python
loss_fn(preds, targets=None, batch=None) -> dict[str, Tensor]
```

- `preds`: model predictions (tensor or `TrunkOutput`, depending on the loss)
- `targets`: ground truth labels (`None` when the ground truth is embedded in `batch`)
- `batch`: raw feature dict forwarded from the DataLoader; required for structure losses that compare against atom-coordinate fields
Returns a dict mapping loss-term names to scalar tensors so callers can log individual terms without knowing loss internals.
forward (abstractmethod)¶

```python
forward(preds: Tensor, targets: Tensor | None = None, batch: dict | None = None) -> dict[str, Tensor]
```

Compute the loss and return a dict of named scalar tensors. Subclasses must implement forward().
Forward Signature¶
```python
def forward(
    self,
    preds: torch.Tensor,
    targets: Optional[torch.Tensor] = None,
    batch: Optional[dict] = None,
) -> dict[str, torch.Tensor]:
    ...
```
| Parameter | Type | Description |
|---|---|---|
| `preds` | `Tensor` | Model predictions |
| `targets` | `Tensor \| None` | Ground truth labels |
| `batch` | `dict \| None` | Full feature dict from the DataLoader |
Returns: dict[str, Tensor] mapping loss term names to scalar tensors.
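A minimal stand-in for this contract, using plain Python floats in place of torch tensors (the class names below are illustrative, not molfun's code):

```python
from abc import ABC, abstractmethod

# Illustrative sketch of the LossFunction contract; plain floats replace
# torch tensors so the example runs without any dependencies.
class LossFunctionSketch(ABC):
    @abstractmethod
    def forward(self, preds, targets=None, batch=None) -> dict:
        ...

    def __call__(self, preds, targets=None, batch=None):
        return self.forward(preds, targets=targets, batch=batch)

class MeanSquaredErrorSketch(LossFunctionSketch):
    def forward(self, preds, targets=None, batch=None):
        mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
        return {"affinity_loss": mse}

loss_fn = MeanSquaredErrorSketch()
result = loss_fn([1.0, 2.0], [0.0, 4.0])

# The dict return lets callers log each named term without knowing
# loss internals:
for name, value in result.items():
    print(f"{name}: {value}")  # affinity_loss: 2.5
```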
Built-in Losses¶
MSELoss¶
Mean Squared Error loss for regression tasks.
```python
from molfun.losses import MSELoss

loss_fn = MSELoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(...)}
```
MAELoss¶
Mean Absolute Error (L1) loss.
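The underlying arithmetic is just the mean of the absolute residuals; a dependency-free sketch of the math (not molfun's implementation):

```python
# Mean absolute error on plain floats (illustrative sketch).
def mae(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)

print(mae([1.0, 3.0], [0.0, 1.0]))  # 1.5
```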
HuberLoss¶
Bases: LossFunction
Huber (smooth L1) loss: behaves like MSE near zero and like MAE for large errors, making it less sensitive to outliers than MSE.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `delta` | `float` | Threshold between the quadratic (L2) and linear (L1) regions | `1.0` |
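The piecewise behavior is easiest to see on a single residual; here is the standard Huber formula in plain Python (a sketch, not molfun's implementation):

```python
# Standard Huber penalty for one residual.
def huber(residual: float, delta: float = 1.0) -> float:
    a = abs(residual)
    if a <= delta:
        return 0.5 * a * a            # quadratic (MSE-like) near zero
    return delta * (a - 0.5 * delta)  # linear (MAE-like) for large errors

print(huber(0.5))  # 0.125 (quadratic region)
print(huber(3.0))  # 2.5 (linear region: grows like |x|, damping outliers)
```

The two branches agree at `|residual| == delta` (both give `0.5 * delta**2`), so the loss is continuous at the switch point.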
PearsonLoss¶
Bases: LossFunction
1 − Pearson correlation coefficient. Optimizes rank ordering directly rather than absolute values, which is useful when experimental affinities have systematic offsets. Requires batch size ≥ 2.
```python
from molfun.losses import PearsonLoss

loss_fn = PearsonLoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(...)}  # 0 = perfect correlation
```
OpenFoldLoss (FAPE)¶
Bases: LossFunction
Composite structure loss that wraps OpenFold's AlphaFoldLoss.
Loss terms (weights configurable via `config.loss` or `with_weights()`):

- `fape`: Frame Aligned Point Error (backbone + side-chain)
- `supervised_chi`: Side-chain torsion angle loss
- `distogram`: Pairwise Cβ-distance distribution loss
- `plddt_loss`: Predicted lDDT confidence loss
- `masked_msa`: Masked MSA reconstruction (disabled by default)
- `experimentally_resolved`: Experimentally resolved atom loss (disabled by default)
- `violation`: Steric clash / bond geometry (disabled by default)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_config` | `ConfigDict` | | required |
| `disable_masked_msa` | `bool` | Zero out the masked-MSA weight. Enable only if the data pipeline produces masked-MSA targets. | `True` |
| `disable_experimentally_resolved` | `bool` | Zero out this term. Enable only if PDB resolution metadata is present in the batch. | `True` |
fape_only (classmethod)¶
FAPE + supervised chi only; no MSA, distogram, or pLDDT terms.

with_weights (classmethod)¶
Override individual loss-term weights.

Example:

```python
OpenFoldLoss.with_weights(config, fape=1.0, masked_msa=0.0)
```
forward¶
Compute the OpenFold structure loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `preds` | `dict` | Raw OpenFold output dict (from `TrunkOutput.extra["_raw_outputs"]`); must contain the keys expected by OpenFold's `AlphaFoldLoss` | required |
| `targets` | `Tensor \| None` | Unused (ground truth is embedded in `batch`) | `None` |
| `batch` | `dict \| None` | Feature dict with ground-truth fields, forwarded from the DataLoader | `None` |

Returns:

| Type | Description |
|---|---|
| `dict[str, Tensor]` | Loss-term names mapped to scalar tensors |
Structure prediction loss from AlphaFold2/OpenFold, including FAPE (Frame Aligned Point Error) and auxiliary losses.
```python
from molfun.losses import OpenFoldLoss

# Full loss (FAPE + aux)
loss_fn = OpenFoldLoss(config)
result = loss_fn(raw_outputs, batch=feature_dict)
# {"structure_loss": tensor(...), "fape": tensor(...), "aux": tensor(...)}

# FAPE only
loss_fn = OpenFoldLoss.fape_only(config)
```
Combining Losses¶
```python
from molfun.losses import LOSS_REGISTRY

# Combine multiple losses with weights
mse = LOSS_REGISTRY["mse"]()
pearson = LOSS_REGISTRY["pearson"]()

result_mse = mse(preds, targets)
result_pearson = pearson(preds, targets)

total = result_mse["affinity_loss"] + 0.5 * result_pearson["affinity_loss"]
total.backward()
```
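The same weighted-sum pattern can be wrapped in a reusable class. A hypothetical sketch with plain floats standing in for tensors (`WeightedSum` is not part of molfun):

```python
# Hypothetical helper (not in molfun): merges weighted loss dicts into one
# dict and adds a combined "total_loss" entry.
class WeightedSum:
    def __init__(self, weighted_losses):
        # weighted_losses: list of (weight, loss_fn) pairs
        self.weighted_losses = weighted_losses

    def __call__(self, preds, targets=None, batch=None):
        terms, total = {}, 0.0
        for weight, fn in self.weighted_losses:
            for name, value in fn(preds, targets=targets, batch=batch).items():
                terms[name] = value
                total += weight * value
        terms["total_loss"] = total
        return terms

# Plain-float stand-ins for the MSE and MAE losses:
def mse(preds, targets=None, batch=None):
    return {"mse_loss": sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)}

def mae(preds, targets=None, batch=None):
    return {"mae_loss": sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)}

combo = WeightedSum([(1.0, mse), (0.5, mae)])
print(combo([1.0, 3.0], [0.0, 1.0]))
# {'mse_loss': 2.5, 'mae_loss': 1.5, 'total_loss': 3.25}
```

Keeping the individual terms in the returned dict preserves the base-class contract: callers can still log each component alongside the combined total.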
Registered Names¶
| Registry Key | Class | Task |
|---|---|---|
| `"mse"` | `MSELoss` | Affinity regression |
| `"mae"` | `MAELoss` | Affinity regression |
| `"huber"` | `HuberLoss` | Affinity regression |
| `"pearson"` | `PearsonLoss` | Affinity ranking |
| `"openfold"` | `OpenFoldLoss` | Structure prediction |