
Fine-Tuning Strategies

Molfun provides four fine-tuning strategies, all implementing the FinetuneStrategy abstract base class. Each strategy defines which parameters are trainable and how they are grouped for the optimizer.

Quick Start

```python
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# Via the model facade
model.fit(train_dataset=ds, strategy="lora", epochs=5)

# Or instantiate directly
from molfun.training.lora import LoRAFinetune

strategy = LoRAFinetune(rank=8, alpha=16)
strategy.fit(model, train_dataset=ds, epochs=5)
```

Base Class

FinetuneStrategy

Bases: ABC

Base class for all fine-tuning strategies.

Subclasses implement setup() (freeze/unfreeze logic) and param_groups() (optimizer parameter groups with per-group LR).

The base provides the full training loop with:

- Linear warmup + cosine/linear/constant scheduler
- Exponential Moving Average (EMA)
- Gradient accumulation
- Mixed precision (AMP)
- Gradient clipping
- Early stopping on val_loss
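To make the two-method contract concrete, here is a minimal illustrative sketch of a custom strategy. It is not molfun's actual code: the class name is hypothetical, and it assumes only a PyTorch-style `named_parameters()` interface.

```python
# Illustrative sketch of the FinetuneStrategy contract -- not molfun's
# actual code. Trains only parameters whose name contains "head".
class HeadOnlySketch:
    def __init__(self, lr_head: float = 1e-3):
        self.lr_head = lr_head

    def setup(self, model) -> None:
        # Freeze everything except head parameters. Idempotent:
        # repeating this assignment changes nothing.
        for name, p in model.named_parameters():
            p.requires_grad = "head" in name

    def param_groups(self, model) -> list[dict]:
        # One optimizer group, at the head learning rate.
        head = [p for n, p in model.named_parameters() if "head" in n]
        return [{"params": head, "lr": self.lr_head}]
```

The base class then drives the loop; a subclass only decides what trains and at which rate.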

setup

```python
setup(model) -> None
```

Configure the model for this strategy: freeze/unfreeze params, inject PEFT layers, attach heads, etc. Called once before training. Idempotent: calling multiple times has no effect.

param_groups abstractmethod

```python
param_groups(model) -> list[dict]
```

Return optimizer parameter groups, e.g.:

```python
[
    {"params": head_params, "lr": 1e-3},
    {"params": lora_params, "lr": 1e-4},
]
```

fit

```python
fit(
    model,
    train_loader: DataLoader,
    val_loader: DataLoader | None = None,
    epochs: int = 10,
    verbose: bool = True,
    tracker=None,
    distributed=None,
    gradient_checkpointing: bool = False,
    checkpoint_dir: str | None = None,
    save_every: int = 0,
    resume_from: str | None = None,
) -> list[dict]
```

Run the full training loop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | | MolfunStructureModel instance. | required |
| train_loader | DataLoader | Training data. | required |
| val_loader | DataLoader \| None | Optional validation data. | None |
| epochs | int | Number of epochs. | 10 |
| tracker | | Optional BaseTracker for experiment logging. | None |
| distributed | | Optional BaseDistributedStrategy (DDPStrategy or FSDPStrategy) for multi-GPU training. | None |
| gradient_checkpointing | bool | Enable activation checkpointing to reduce peak VRAM (~40-60% savings, ~25-35% slower). | False |
| checkpoint_dir | str \| None | Directory for saving periodic and best checkpoints. If None, no intermediate checkpoints are saved. | None |
| save_every | int | Save a checkpoint every N epochs (0 = only best). | 0 |
| resume_from | str \| None | Path to a checkpoint directory to resume training from. Loads model weights and resumes from the saved epoch. | None |

Returns:

| Type | Description |
| --- | --- |
| list[dict] | List of per-epoch metric dicts. |

apply_ema

```python
apply_ema(model) -> None
```

Copy EMA weights into the model permanently (for export/inference).

The fit() template method implements the training loop. Subclasses customize behaviour through _setup_impl() and param_groups().
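The EMA bookkeeping behind apply_ema() can be sketched in a few lines. This is illustrative, not molfun's implementation: after each optimizer step the shadow copy drifts toward the live weights, and apply_ema() would copy the shadow back into the model for export.

```python
# Illustrative EMA bookkeeping (not molfun's implementation). After
# each optimizer step, the shadow weights drift toward the live
# weights; apply_ema() would copy the shadow back into the model.
def ema_update(shadow: dict, params: dict, decay: float) -> None:
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value

params = {"w": 1.0}   # live weight after some training steps
shadow = {"w": 0.0}   # EMA copy
for _ in range(3):
    ema_update(shadow, params, decay=0.5)
# shadow["w"] converges toward params["w"]: 0.5, 0.75, 0.875
```

A realistic decay like 0.999 converges much more slowly, which is exactly the smoothing effect EMA is used for.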


FullFinetune


Bases: FinetuneStrategy

Unfreeze the entire model with layer-wise LR decay.

Each trunk block gets `lr * layer_lr_decay^(N - block_index)`. The head gets the base LR. Input embedder and structure module get separate configurable LRs.

Works with any adapter — the number of blocks and component access is resolved at setup time via the adapter's generic interface.
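The decay rule above can be checked with a worked example (a sketch; `block_lrs` is a hypothetical helper, not a molfun API):

```python
# Worked example of the decay rule quoted above: block i of N gets
# lr * layer_lr_decay ** (N - i), so earlier blocks train more slowly.
# block_lrs is a hypothetical helper, not a molfun API.
def block_lrs(lr: float, layer_lr_decay: float, n_blocks: int) -> list[float]:
    return [lr * layer_lr_decay ** (n_blocks - i) for i in range(n_blocks)]

lrs = block_lrs(lr=1e-5, layer_lr_decay=0.9, n_blocks=4)
# lrs[0] = 1e-5 * 0.9**4 (earliest block), lrs[-1] = 1e-5 * 0.9 (last block)
```

Learning rates increase monotonically toward the output, so the blocks closest to the head adapt fastest.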

Usage:

```python
strategy = FullFinetune(
    lr=1e-5, lr_head=1e-3,
    layer_lr_decay=0.9,
    warmup_steps=1000, ema_decay=0.999,
    accumulation_steps=8,
)
history = strategy.fit(model, train_loader, val_loader, epochs=5)
```

Fine-tune all model parameters. Highest capacity but requires the most memory and data.

```python
from molfun.training.full import FullFinetune

strategy = FullFinetune(lr=1e-4, weight_decay=0.01)
strategy.fit(model, train_dataset=ds, val_dataset=val_ds, epochs=20)
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| lr | float | 1e-4 | Learning rate |
| weight_decay | float | 0.0 | L2 regularization |
| max_grad_norm | float \| None | 1.0 | Gradient clipping norm |
| warmup_steps | int | 0 | Linear warmup steps |
| scheduler | str | "cosine" | LR scheduler type |

HeadOnlyFinetune


Bases: FinetuneStrategy

Freeze every trunk parameter, train only the prediction head.

Usage:

```python
strategy = HeadOnlyFinetune(lr=1e-3, warmup_steps=50)
history = strategy.fit(model, train_loader, val_loader, epochs=20)
```

Freeze the trunk and fine-tune only the prediction head. Fast and memory-efficient; best for transfer learning with limited data.

```python
from molfun.training.head_only import HeadOnlyFinetune

strategy = HeadOnlyFinetune(lr=5e-4)
strategy.fit(model, train_dataset=ds, epochs=50)
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| lr | float | 5e-4 | Learning rate |
| weight_decay | float | 0.0 | L2 regularization |
| max_grad_norm | float \| None | 1.0 | Gradient clipping norm |

LoRAFinetune


Bases: HeadOnlyFinetune

Freeze trunk, inject LoRA into attention layers, train LoRA + head.

When target_modules is None, the adapter's default_peft_targets is used — each adapter knows the naming convention of its own attention projections (OpenFold: linear_q/linear_v, Protenix: q_proj/v_proj, etc.). This makes LoRA work out of the box with any registered adapter.
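The underlying LoRA arithmetic is the same regardless of adapter: the adapted projection computes y = W x + (alpha / rank) · B(A x) with W frozen. A pure-Python sketch of that rule (illustrative math, not molfun code):

```python
# Pure-Python sketch of the LoRA forward rule (illustrative math, not
# molfun code): y = W x + (alpha / rank) * B (A x), with W frozen and
# only the low-rank factors A (rank x d_in) and B (d_out x rank) trained.
def matvec(M, x):
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha: float, rank: int):
    base = matvec(W, x)               # frozen projection
    delta = matvec(B, matvec(A, x))   # low-rank update
    scale = alpha / rank
    return [b + scale * d for b, d in zip(base, delta)]

# With B zero-initialized (the standard LoRA init), the adapted layer
# starts out identical to the frozen layer.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.3, 0.1]]           # rank 1, d_in 2
B = [[0.0], [0.0]]         # d_out 2, rank 1
y = lora_forward(W, A, B, [2.0, 5.0], alpha=16.0, rank=1)  # == [2.0, 5.0]
```

Merging (as in `model.merge()`) folds the same update into the base weights, W' = W + (alpha / rank) · B A, so inference needs no extra matmuls.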

Usage:

```python
strategy = LoRAFinetune(
    rank=8, alpha=16.0,
    lr_head=1e-3, lr_lora=1e-4,
    warmup_steps=200, ema_decay=0.999,
    accumulation_steps=4,
)
history = strategy.fit(model, train_loader, val_loader, epochs=10)

# Export merged weights
model.merge()
```

Apply LoRA adapters and fine-tune only the adapter parameters. Excellent trade-off between capacity and efficiency.

```python
from molfun.training.lora import LoRAFinetune

strategy = LoRAFinetune(
    rank=8,
    alpha=16,
    target_modules=["q_proj", "v_proj"],
    lr=2e-4,
)
strategy.fit(model, train_dataset=ds, epochs=10)
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| rank | int | 8 | LoRA rank |
| alpha | float | 16.0 | LoRA scaling factor |
| dropout | float | 0.0 | LoRA dropout |
| target_modules | list[str] \| None | None | Module name patterns to adapt |
| lr | float | 2e-4 | Learning rate |
| weight_decay | float | 0.0 | L2 regularization |
| max_grad_norm | float \| None | 1.0 | Gradient clipping norm |

PartialFinetune


Bases: FinetuneStrategy

Freeze everything except the last unfreeze_last_n trunk blocks, the structure module (optionally), and the task head.

Works with any adapter that implements get_trunk_blocks() — OpenFold (Evoformer blocks), Protenix (Pairformer blocks), ESMFold (transformer layers), or custom models built with ModelBuilder.

Supports separate LRs for trunk vs head (discriminative fine-tuning).

Usage:

```python
strategy = PartialFinetune(
    unfreeze_last_n=6,
    unfreeze_structure_module=True,
    lr_trunk=1e-5, lr_head=1e-3,
    warmup_steps=500, ema_decay=0.999,
)
history = strategy.fit(model, train_loader, val_loader, epochs=10)
```
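The freezing rule can be sketched against a generic ordered block list, as returned by an adapter's get_trunk_blocks(). The helper below is hypothetical (not a molfun API) and assumes PyTorch-style requires_grad flags on parameters:

```python
# Hypothetical helper (not a molfun API) showing the unfreezing rule:
# only the last n blocks in the adapter's ordered block list become
# trainable. Assumes PyTorch-style requires_grad flags on parameters.
from types import SimpleNamespace

def unfreeze_last_n(blocks, n: int) -> None:
    for i, block in enumerate(blocks):
        trainable = i >= len(blocks) - n
        for p in block:
            p.requires_grad = trainable

blocks = [[SimpleNamespace(requires_grad=False)] for _ in range(6)]
unfreeze_last_n(blocks, 2)
# blocks[4:] are now trainable; blocks[:4] stay frozen
```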

Unfreeze the last N trunk blocks plus the prediction head. A middle ground between full fine-tuning and head-only.

```python
from molfun.training.partial import PartialFinetune

strategy = PartialFinetune(
    n_unfrozen_blocks=4,
    lr=1e-4,
    head_lr=5e-4,
)
strategy.fit(model, train_dataset=ds, epochs=15)
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_unfrozen_blocks | int | 4 | Number of trunk blocks to unfreeze (from the end) |
| lr | float | 1e-4 | Learning rate for unfrozen trunk blocks |
| head_lr | float \| None | None | Separate learning rate for the head (defaults to lr) |
| weight_decay | float | 0.0 | L2 regularization |
| max_grad_norm | float \| None | 1.0 | Gradient clipping norm |

Comparison

| Strategy | Trainable Params | Memory | Best For |
| --- | --- | --- | --- |
| FullFinetune | 100% | High | Large datasets, maximum accuracy |
| HeadOnlyFinetune | ~1-2% | Low | Small datasets, fast iteration |
| LoRAFinetune | ~0.3-1% | Low | General-purpose fine-tuning |
| PartialFinetune | ~10-30% | Medium | Moderate data, domain adaptation |
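To check which row of the table a configured model actually lands in, a small helper can count the trainable fraction after setup(). The function is hypothetical (not a molfun API) and assumes a torch-style named_parameters() where each parameter exposes numel() and requires_grad:

```python
# Hypothetical helper (not a molfun API): fraction of parameters that
# will receive gradients, for comparison against the table above.
# Assumes torch-style named_parameters() with numel()/requires_grad.
def trainable_fraction(model) -> float:
    total = trainable = 0
    for _, p in model.named_parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return trainable / max(total, 1)
```

For example, after a head-only setup the result should sit near the ~1-2% row.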