Fine-Tuning Strategies¶
Molfun provides four fine-tuning strategies, all implementing the
FinetuneStrategy abstract base class. Each strategy defines which
parameters are trainable and how they are grouped for the optimizer.
Quick Start¶
```python
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# Via the model facade
model.fit(train_dataset=ds, strategy="lora", epochs=5)

# Or instantiate a strategy directly (fit() takes DataLoader objects)
from molfun.training.lora import LoRAFinetune

strategy = LoRAFinetune(rank=8, alpha=16)
strategy.fit(model, train_loader, epochs=5)
```
Base Class¶
FinetuneStrategy ¶
Bases: ABC
Base class for all fine-tuning strategies.
Subclasses implement setup() (freeze/unfreeze logic) and
param_groups() (optimizer parameter groups with per-group LR).
The base provides the full training loop with:

- Linear warmup + cosine/linear/constant scheduler
- Exponential Moving Average (EMA)
- Gradient accumulation
- Mixed precision (AMP)
- Gradient clipping
- Early stopping on val_loss
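The warmup-plus-cosine schedule can be sketched as a plain function (an illustrative sketch of the math, not molfun's scheduler internals):

```python
import math

def lr_at_step(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to base_lr, then cosine decay toward zero."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```

The LR ramps linearly over the warmup window, peaks at `base_lr`, then follows a half-cosine down to zero at `total_steps`.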
setup ¶
Configure the model for this strategy: freeze/unfreeze params, inject PEFT layers, attach heads, etc. Called once before training. Idempotent: calling multiple times has no effect.
param_groups abstractmethod ¶
Return optimizer parameter groups, e.g.:

```python
[
    {"params": head_params, "lr": 1e-3},
    {"params": lora_params, "lr": 1e-4},
]
```
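A toy sketch of the subclass contract (stand-in types, not molfun's actual classes): `setup()` decides what is trainable, `param_groups()` returns the groups the training loop hands to the optimizer.

```python
from abc import ABC, abstractmethod

class Strategy(ABC):
    """Simplified stand-in for FinetuneStrategy's two hooks."""
    @abstractmethod
    def setup(self, params: dict) -> None: ...
    @abstractmethod
    def param_groups(self, params: dict) -> list[dict]: ...

class HeadOnly(Strategy):
    """Train only parameters whose name starts with 'head.'."""
    def setup(self, params):
        for name, p in params.items():
            p["requires_grad"] = name.startswith("head.")

    def param_groups(self, params):
        head = [p for n, p in params.items() if n.startswith("head.")]
        return [{"params": head, "lr": 5e-4}]
```

Here `params` is a plain dict standing in for a module's named parameters; the real base class operates on the model via the adapter interface.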
fit ¶
```python
fit(model, train_loader: DataLoader, val_loader: DataLoader | None = None, epochs: int = 10, verbose: bool = True, tracker=None, distributed=None, gradient_checkpointing: bool = False, checkpoint_dir: str | None = None, save_every: int = 0, resume_from: str | None = None) -> list[dict]
```
Run the full training loop.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `MolfunStructureModel` | MolfunStructureModel instance. | *required* |
| `train_loader` | `DataLoader` | Training data. | *required* |
| `val_loader` | `DataLoader \| None` | Optional validation data. | `None` |
| `epochs` | `int` | Number of epochs. | `10` |
| `tracker` | | Optional BaseTracker for experiment logging. | `None` |
| `distributed` | | Optional | `None` |
| `gradient_checkpointing` | `bool` | Enable activation checkpointing to reduce peak VRAM (~40-60% savings, ~25-35% slower). | `False` |
| `checkpoint_dir` | `str \| None` | Directory for saving periodic and best checkpoints. If None, no intermediate checkpoints are saved. | `None` |
| `save_every` | `int` | Save a checkpoint every N epochs (0 = only best). | `0` |
| `resume_from` | `str \| None` | Path to a checkpoint directory to resume training from. Loads model weights and resumes from the saved epoch. | `None` |

Returns:

| Type | Description |
|---|---|
| `list[dict]` | List of per-epoch metric dicts. |
apply_ema ¶
Copy EMA weights into the model permanently (for export/inference).
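The EMA weights that `apply_ema()` copies in are maintained with the usual exponential update; a minimal numeric sketch (not molfun's implementation):

```python
def ema_update(ema: list[float], current: list[float], decay: float = 0.999) -> list[float]:
    """shadow <- decay * shadow + (1 - decay) * current, per parameter."""
    return [decay * e + (1.0 - decay) * c for e, c in zip(ema, current)]
```

With `decay` close to 1, the shadow weights track a slow-moving average of the training weights, which is what gets baked in at export time.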
The fit() template method implements the training loop. Subclasses
customize behaviour through _setup_impl() and param_groups().
FullFinetune¶
FullFinetune ¶
Bases: FinetuneStrategy
Unfreeze the entire model with layer-wise LR decay.
Each trunk block gets `lr * layer_lr_decay ^ (N - block_index)`.
The head gets the base LR. Input embedder and structure module
get separate configurable LRs.
Works with any adapter — the number of blocks and component access is resolved at setup time via the adapter's generic interface.
Usage:

```python
strategy = FullFinetune(
    lr=1e-5, lr_head=1e-3,
    layer_lr_decay=0.9,
    warmup_steps=1000, ema_decay=0.999,
    accumulation_steps=8,
)
history = strategy.fit(model, train_loader, val_loader, epochs=5)
```
Fine-tune all model parameters. Highest capacity but requires the most memory and data.
```python
from molfun.training.full import FullFinetune

strategy = FullFinetune(lr=1e-4, weight_decay=0.01)
strategy.fit(model, train_loader, val_loader, epochs=20)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `lr` | `float` | `1e-4` | Learning rate |
| `weight_decay` | `float` | `0.0` | L2 regularization |
| `max_grad_norm` | `float \| None` | `1.0` | Gradient clipping norm |
| `warmup_steps` | `int` | `0` | Linear warmup steps |
| `scheduler` | `str` | `"cosine"` | LR scheduler type |
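The layer-wise decay rule from the docstring, `lr * layer_lr_decay ^ (N - block_index)`, can be tabulated directly. This helper is illustrative only, not part of molfun:

```python
def block_lrs(lr: float, layer_lr_decay: float, n_blocks: int) -> list[float]:
    """Per-block LRs: early blocks decay the most, the last block sits
    just below the base lr (the head itself gets the base lr)."""
    return [lr * layer_lr_decay ** (n_blocks - i) for i in range(n_blocks)]
```

For `lr=1e-5`, `layer_lr_decay=0.9` and 4 blocks, the earliest block trains at `1e-5 * 0.9**4` and the last at `1e-5 * 0.9`, so representations closest to the input change the least.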
HeadOnlyFinetune¶
HeadOnlyFinetune ¶
Bases: FinetuneStrategy
Freeze every trunk parameter, train only the prediction head.
Usage:

```python
strategy = HeadOnlyFinetune(lr=1e-3, warmup_steps=50)
history = strategy.fit(model, train_loader, val_loader, epochs=20)
```
Freeze the trunk and fine-tune only the prediction head. Fast and memory-efficient; best for transfer learning with limited data.
```python
from molfun.training.head_only import HeadOnlyFinetune

strategy = HeadOnlyFinetune(lr=5e-4)
strategy.fit(model, train_loader, epochs=50)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `lr` | `float` | `5e-4` | Learning rate |
| `weight_decay` | `float` | `0.0` | L2 regularization |
| `max_grad_norm` | `float \| None` | `1.0` | Gradient clipping norm |
LoRAFinetune¶
LoRAFinetune ¶
Bases: HeadOnlyFinetune
Freeze trunk, inject LoRA into attention layers, train LoRA + head.
When target_modules is None, the adapter's default_peft_targets
is used — each adapter knows the naming convention of its own attention
projections (OpenFold: linear_q/linear_v, Protenix: q_proj/v_proj,
etc.). This makes LoRA work out of the box with any registered adapter.
Usage:

```python
strategy = LoRAFinetune(
    rank=8, alpha=16.0,
    lr_head=1e-3, lr_lora=1e-4,
    warmup_steps=200, ema_decay=0.999,
    accumulation_steps=4,
)
history = strategy.fit(model, train_loader, val_loader, epochs=10)

# Export merged weights
model.merge()
```
Apply LoRA adapters and fine-tune only the adapter parameters. Excellent trade-off between capacity and efficiency.
```python
from molfun.training.lora import LoRAFinetune

strategy = LoRAFinetune(
    rank=8,
    alpha=16,
    target_modules=["q_proj", "v_proj"],
    lr=2e-4,
)
strategy.fit(model, train_loader, epochs=10)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `rank` | `int` | `8` | LoRA rank |
| `alpha` | `float` | `16.0` | LoRA scaling factor |
| `dropout` | `float` | `0.0` | LoRA dropout |
| `target_modules` | `list[str] \| None` | `None` | Module name patterns to adapt |
| `lr` | `float` | `2e-4` | Learning rate |
| `weight_decay` | `float` | `0.0` | L2 regularization |
| `max_grad_norm` | `float \| None` | `1.0` | Gradient clipping norm |
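Numerically, LoRA replaces a frozen weight `W` with `W + (alpha / rank) * B @ A`, where `A` is `rank x d_in`, `B` is `d_out x rank`, and `B` starts at zero so training begins from the pretrained behaviour. A minimal pure-Python sketch of the forward pass, not molfun's injection code:

```python
def matvec(M, v):
    """Plain matrix-vector product over nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def lora_forward(x, W, A, B, alpha=16.0, rank=8):
    """y = W @ x + (alpha / rank) * B @ (A @ x), with W frozen."""
    scale = alpha / rank
    base = matvec(W, x)               # frozen pretrained path
    delta = matvec(B, matvec(A, x))   # low-rank trainable path
    return [b + scale * d for b, d in zip(base, delta)]
```

Only `A` and `B` receive gradients, which is why the trainable fraction stays well under 1% of the model.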
PartialFinetune¶
PartialFinetune ¶
Bases: FinetuneStrategy
Freeze everything except the last unfreeze_last_n trunk blocks,
the structure module (optionally), and the task head.
Works with any adapter that implements get_trunk_blocks() — OpenFold
(Evoformer blocks), Protenix (Pairformer blocks), ESMFold (transformer
layers), or custom models built with ModelBuilder.
Supports separate LRs for trunk vs head (discriminative fine-tuning).
Usage:

```python
strategy = PartialFinetune(
    unfreeze_last_n=6,
    unfreeze_structure_module=True,
    lr_trunk=1e-5, lr_head=1e-3,
    warmup_steps=500, ema_decay=0.999,
)
history = strategy.fit(model, train_loader, val_loader, epochs=10)
```
Unfreeze the last N trunk blocks plus the prediction head. A middle ground between full fine-tuning and head-only.
```python
from molfun.training.partial import PartialFinetune

strategy = PartialFinetune(
    n_unfrozen_blocks=4,
    lr=1e-4,
    head_lr=5e-4,
)
strategy.fit(model, train_loader, epochs=15)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_unfrozen_blocks` | `int` | `4` | Number of trunk blocks to unfreeze (from the end) |
| `lr` | `float` | `1e-4` | Learning rate for unfrozen trunk blocks |
| `head_lr` | `float \| None` | `None` | Separate learning rate for the head (defaults to `lr`) |
| `weight_decay` | `float` | `0.0` | L2 regularization |
| `max_grad_norm` | `float \| None` | `1.0` | Gradient clipping norm |
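The freezing rule reduces to keeping only the last `n_unfrozen_blocks` trainable. A sketch over an ordered block list (illustrative only, assuming the ordering returned by the adapter's `get_trunk_blocks()`):

```python
def trainable_mask(n_blocks: int, n_unfrozen_blocks: int) -> list[bool]:
    """True for blocks that stay trainable (the last N), False for frozen ones."""
    return [i >= n_blocks - n_unfrozen_blocks for i in range(n_blocks)]
```

For a 48-block Evoformer trunk with `n_unfrozen_blocks=6`, blocks 0-41 stay frozen and 42-47 train, alongside the head (and optionally the structure module).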
Comparison¶
| Strategy | Trainable Params | Memory | Best For |
|---|---|---|---|
| FullFinetune | 100% | High | Large datasets, maximum accuracy |
| HeadOnlyFinetune | ~1-2% | Low | Small datasets, fast iteration |
| LoRAFinetune | ~0.3-1% | Low | General-purpose fine-tuning |
| PartialFinetune | ~10-30% | Medium | Moderate data, domain adaptation |