
Loss Functions

Molfun provides a registry of loss functions for structure prediction and affinity tasks. All losses implement the LossFunction abstract base class and are registered in LOSS_REGISTRY.

Quick Start

from molfun.losses import LOSS_REGISTRY, MSELoss, PearsonLoss

# Use the registry
loss_fn = LOSS_REGISTRY["mse"]()
result = loss_fn(preds, targets)
print(result)  # {"affinity_loss": tensor(0.42)}

# Or instantiate directly
loss_fn = PearsonLoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(0.15)}

# Register a custom loss
from molfun.losses import LossFunction

@LOSS_REGISTRY.register("tmscore")
class TMScoreLoss(LossFunction):
    def forward(self, preds, targets=None, batch=None):
        # compute_tm_score: user-supplied function returning a similarity score tensor
        score = compute_tm_score(preds, targets)
        return {"tmscore_loss": 1.0 - score}

LOSS_REGISTRY

LOSS_REGISTRY = LossRegistry()
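LossRegistry itself isn't documented on this page. A minimal dict-backed sketch (a hypothetical illustration; Molfun's actual class may differ) shows why @LOSS_REGISTRY.register(...) both stores the class under its key and returns it unchanged:

```python
class LossRegistry(dict):
    """Minimal sketch of a decorator-based registry.

    Hypothetical illustration only; Molfun's actual LossRegistry may
    differ in details such as name validation or duplicate handling.
    """

    def register(self, name):
        def decorator(cls):
            self[name] = cls  # store the class under its registry key
            return cls        # return the class unchanged, so the decorator is transparent
        return decorator
```

With this shape, LOSS_REGISTRY["mse"]() simply looks up the class and instantiates it.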

LossFunction (ABC)


Bases: ABC, Module

Abstract base for all Molfun loss functions.

A LossFunction is a callable nn.Module with a unified signature:

loss_fn(preds, targets=None, batch=None) -> dict[str, Tensor]

preds — model predictions (tensor or TrunkOutput, depending on the loss)

targets — ground truth labels (None when the ground truth is embedded in batch)

batch — raw feature dict forwarded from the DataLoader; required for structure losses that compare against atom-coordinate fields

Returns a dict mapping loss-term names to scalar tensors so callers can log individual terms without knowing loss internals.

forward (abstractmethod)

forward(preds: Tensor, targets: Tensor | None = None, batch: dict | None = None) -> dict[str, torch.Tensor]

Compute loss and return a dict of named scalar tensors.


Forward Signature

def forward(
    self,
    preds: torch.Tensor,
    targets: Optional[torch.Tensor] = None,
    batch: Optional[dict] = None,
) -> dict[str, torch.Tensor]:
    ...
Parameter Type Description
preds Tensor Model predictions
targets Tensor | None Ground truth labels
batch dict | None Full feature dict from the DataLoader

Returns: dict[str, Tensor] mapping loss term names to scalar tensors.
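As a concrete illustration of this contract, here is a hypothetical RMSE loss (not part of Molfun; plain nn.Module stands in for LossFunction so the sketch is self-contained):

```python
import torch
import torch.nn as nn


class RMSELoss(nn.Module):
    """Hypothetical loss obeying the Molfun forward contract.

    Subclasses nn.Module directly so the sketch runs on its own;
    a real implementation would subclass molfun.losses.LossFunction.
    """

    def forward(self, preds, targets=None, batch=None):
        # Return a dict of named scalar tensors, as the contract requires.
        mse = torch.mean((preds - targets) ** 2)
        return {"affinity_loss": torch.sqrt(mse)}
```

Because the result is a dict, a training loop can log each named term without knowing which loss produced it.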


Built-in Losses

MSELoss


Bases: LossFunction

Mean Squared Error — standard choice for ΔG / pKd regression.


from molfun.losses import MSELoss

loss_fn = MSELoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(...)}

MAELoss


Bases: LossFunction

Mean Absolute Error (L1) — less sensitive to outliers than MSE.

from molfun.losses import MAELoss

loss_fn = MAELoss()
result = loss_fn(preds, targets)

HuberLoss


Bases: LossFunction

Huber (smooth L1) loss — behaves like MSE near zero, MAE for large errors.


from molfun.losses import HuberLoss

loss_fn = HuberLoss(delta=1.0)
result = loss_fn(preds, targets)
Parameter Type Default Description
delta float 1.0 Threshold for switching between L1 and L2
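Assuming HuberLoss matches the standard definition (0.5·err² inside the delta region, delta·(|err| − 0.5·delta) outside it), the two regimes can be checked against PyTorch's built-in huber_loss:

```python
import torch
import torch.nn.functional as F

preds = torch.tensor([0.5, 3.0])
targets = torch.zeros(2)

# Per-element Huber with delta=1.0:
#   |err| <= 1  ->  0.5 * err**2          (quadratic region: 0.5 * 0.25 = 0.125)
#   |err| >  1  ->  1.0 * (|err| - 0.5)   (linear region:    3.0 - 0.5  = 2.5)
per_elem = F.huber_loss(preds, targets, delta=1.0, reduction="none")
# per_elem -> tensor([0.1250, 2.5000])
```

The small error (0.5) lands in the quadratic region; the large one (3.0) grows only linearly, which is where the outlier robustness comes from.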

PearsonLoss


Bases: LossFunction

1 − Pearson correlation coefficient.

Optimizes rank ordering directly rather than absolute values. Useful when experimental affinities have systematic offsets. Requires batch_size ≥ 2.


from molfun.losses import PearsonLoss

loss_fn = PearsonLoss()
result = loss_fn(preds, targets)
# {"affinity_loss": tensor(...)}  # 0 = perfect correlation
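For intuition, 1 − Pearson reduces to centring both vectors and normalising their dot product. This is a minimal sketch, not Molfun's implementation; the eps guard is an assumption:

```python
import torch


def pearson_loss(preds, targets, eps=1e-8):
    """Sketch of 1 - Pearson correlation (illustrative only)."""
    p = preds - preds.mean()      # centre predictions
    t = targets - targets.mean()  # centre targets
    r = (p * t).sum() / (p.norm() * t.norm() + eps)  # Pearson r in [-1, 1]
    return 1.0 - r
```

Perfectly correlated inputs give a loss near 0 and anti-correlated inputs a loss near 2, and because the statistic needs a sample mean and variance, at least two samples per batch are required.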

OpenFoldLoss (FAPE)


Bases: LossFunction

Composite structure loss that wraps OpenFold's AlphaFoldLoss.

Loss terms (weights configurable via config.loss or with_weights()):

- fape — Frame Aligned Point Error (backbone + side-chain)
- supervised_chi — Side-chain torsion angle loss
- distogram — Pairwise Cβ-distance distribution loss
- plddt_loss — Predicted lDDT confidence loss
- masked_msa — Masked MSA reconstruction (disabled by default)
- experimentally_resolved — Experimentally resolved atom loss (disabled by default)
- violation — Steric clash / bond geometry (disabled by default)

Parameters:

loss_config (ConfigDict, required) — config.loss sub-config from OpenFold's model_config().

disable_masked_msa (bool, default True) — zero out the masked-MSA weight. Enable only if the data pipeline produces true_msa / bert_mask.

disable_experimentally_resolved (bool, default True) — zero out this term. Enable only if PDB resolution metadata is present in the batch.

fape_only (classmethod)
fape_only(config) -> OpenFoldLoss

FAPE + supervised chi only — no MSA, distogram or pLDDT terms.

with_weights (classmethod)
with_weights(config, **weights) -> OpenFoldLoss

Override individual loss-term weights.

Example:

OpenFoldLoss.with_weights(config, fape=1.0, masked_msa=0.0)

forward
forward(preds, targets: Tensor | None = None, batch: dict | None = None) -> dict[str, torch.Tensor]

Compute OpenFold structure loss.

Parameters:

preds (required) — raw OpenFold output dict (from TrunkOutput.extra["_raw_outputs"]). Must contain keys: sm, distogram_logits, final_atom_positions, final_atom_mask, etc.

targets (Tensor | None, default None) — unused; ground truth is embedded in batch.

batch (dict | None, default None) — feature dict with ground-truth fields produced by OpenFoldFeaturizer. Required at call time despite the None default.

Returns:

dict[str, Tensor] — {"structure_loss": scalar_tensor}

describe
describe() -> dict

Return the active loss weights for logging / inspection.

Structure prediction loss from AlphaFold2/OpenFold, including FAPE (Frame Aligned Point Error) and auxiliary losses.

from molfun.losses import OpenFoldLoss

# Full loss (FAPE + aux)
loss_fn = OpenFoldLoss(config)
result = loss_fn(raw_outputs, batch=feature_dict)
# {"structure_loss": tensor(...), "fape": tensor(...), "aux": tensor(...)}

# FAPE only
loss_fn = OpenFoldLoss.fape_only(config)

Combining Losses

from molfun.losses import LOSS_REGISTRY

# Use multiple losses with weights
mse = LOSS_REGISTRY["mse"]()
pearson = LOSS_REGISTRY["pearson"]()

result_mse = mse(preds, targets)
result_pearson = pearson(preds, targets)

total = result_mse["affinity_loss"] + 0.5 * result_pearson["affinity_loss"]
total.backward()
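The manual pattern above can be wrapped in a small combinator (a hypothetical helper, not a Molfun API). Prefixing each term avoids key collisions, since both MSE and Pearson report under "affinity_loss":

```python
class WeightedLossSum:
    """Hypothetical combinator for Molfun-style loss functions.

    Each entry is a (prefix, loss_fn, weight) triple, where loss_fn
    follows the loss_fn(preds, targets=None, batch=None) -> dict contract.
    """

    def __init__(self, entries):
        self.entries = entries

    def __call__(self, preds, targets=None, batch=None):
        out = {}
        total = None
        for prefix, fn, weight in self.entries:
            for name, value in fn(preds, targets=targets, batch=batch).items():
                out[f"{prefix}/{name}"] = value  # keep individual terms for logging
                term = weight * value
                total = term if total is None else total + term
        out["total_loss"] = total
        return out
```

When the individual terms are tensors, total_loss is a tensor too and can be passed straight to .backward().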

Registered Names

Registry Key Class Task
"mse" MSELoss Affinity regression
"mae" MAELoss Affinity regression
"huber" HuberLoss Affinity regression
"pearson" PearsonLoss Affinity ranking
"openfold" OpenFoldLoss Structure prediction