MolfunStructureModel¶
The central facade for all Molfun operations: loading pretrained models, running predictions, fine-tuning, module swapping, export, and Hub integration.
Quick Start¶
```python
from molfun import MolfunStructureModel

# Load a pretrained model
model = MolfunStructureModel.from_pretrained("openfold_v2")

# Predict a structure
output = model.predict("MKFLILLFNILCLFPVLAADNH...")

# Fine-tune on a custom dataset
model.fit(
    train_dataset=train_ds,
    val_dataset=val_ds,
    epochs=10,
    strategy="lora",
)

# Save and push to hub
model.save("./my_model")
model.push_to_hub("myorg/finetuned-openfold")
```
Class Reference¶
MolfunStructureModel ¶
Unified API for protein structure models.
Wraps any registered adapter (OpenFold, ESMFold, ...) with a common interface for inference, fine-tuning, task heads, and checkpointing.
The model itself is agnostic to the training strategy. Strategies (HeadOnly, LoRA, Partial, Full) are passed to `fit()` and handle all freezing, parameter groups, schedulers, EMA, and so on.
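This separation follows the classic strategy pattern. A minimal sketch of the idea, using hypothetical stand-in classes (`Param`, `ToyModel`, `HeadOnly`) rather than Molfun's actual types:

```python
# Illustrative strategy-pattern sketch: the model knows nothing
# about freezing; the strategy object decides what trains.

class Param:
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

class ToyModel:
    def __init__(self):
        self.params = [Param("trunk.w"), Param("head.w")]

class HeadOnly:
    """Freeze everything except parameters under `head.`."""
    def prepare(self, model):
        for p in model.params:
            p.requires_grad = p.name.startswith("head.")
        return [p for p in model.params if p.requires_grad]

model = ToyModel()
trainable = HeadOnly().prepare(model)
# Only "head.w" remains trainable; the trunk is frozen.
```

Swapping in a different strategy class changes the trainable set without touching the model, which is why `fit()` accepts the strategy as an argument.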
__init__ ¶
__init__(name: str, model: Module | None = None, config: object | None = None, weights: str | None = None, device: str = 'cuda', head: str | None = None, head_config: dict | None = None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Model backend (`"openfold"`, `"esmfold"`, ...). | required |
| `model` | `Module \| None` | Pre-built `nn.Module`. If `None`, built from `config`. | `None` |
| `config` | `object \| None` | Backend-specific config object. | `None` |
| `weights` | `str \| None` | Path to model checkpoint. | `None` |
| `device` | `str` | Target device. | `'cuda'` |
| `head` | `str \| None` | Task head name (`"affinity"`). | `None` |
| `head_config` | `dict \| None` | Head kwargs (`single_dim`, `hidden_dim`, ...). | `None` |
from_pretrained (classmethod) ¶
from_pretrained(name: str = 'openfold', device: str = 'cpu', head: str | None = None, head_config: dict | None = None, cache_dir: str | None = None, force_download: bool = False) -> MolfunStructureModel
Load a pretrained model with automatic weight download.
Downloads weights to ~/.molfun/weights/<name>/ on first call,
then loads from cache on subsequent calls.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Pretrained model name. See `available_pretrained()`. | `'openfold'` |
| `device` | `str` | Target device (`"cpu"` or `"cuda"`). | `'cpu'` |
| `head` | `str \| None` | Optional task head (`"affinity"`, `"structure"`). | `None` |
| `head_config` | `dict \| None` | Head kwargs. | `None` |
| `cache_dir` | `str \| None` | Override weight cache directory. | `None` |
| `force_download` | `bool` | Re-download even if cached. | `False` |
Returns:
| Type | Description |
|---|---|
| `MolfunStructureModel` | Ready-to-use `MolfunStructureModel`. |
Usage::

    model = MolfunStructureModel.from_pretrained("openfold")
    output = model.predict("MKWVTFISLLLLFSSAYS")
predict ¶
Run inference (no grad, eval mode).
Accepts either a feature dict (batch) or a raw amino acid sequence string. When a string is passed, it is automatically featurized.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch_or_sequence` | | Feature dict or amino acid sequence string (e.g. `"MKWVTFISLLLLFSSAYS"`). | required |

Returns:

| Type | Description |
|---|---|
| `TrunkOutput` | `TrunkOutput` with `single_repr`, `pair_repr`, `structure_coords`, and `confidence` (pLDDT). |

Usage::

    # From sequence string
    output = model.predict("MKWVTFISLLLLFSSAYS")

    # From pre-built feature dict
    output = model.predict(batch)
forward ¶
Full forward: adapter → head.
Returns dict with "trunk_output" and optionally "preds". For StructureLossHead, "preds" is the scalar structure loss.
fit ¶
fit(train_loader: DataLoader, val_loader: DataLoader | None = None, strategy: FinetuneStrategy | None = None, epochs: int = 10, gradient_checkpointing: bool = False, tracker=None, checkpoint_dir: str | None = None, save_every: int = 0, resume_from: str | None = None) -> list[dict]
Fine-tune the model using the given strategy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `train_loader` | `DataLoader` | Training data. | required |
| `val_loader` | `DataLoader \| None` | Validation data (optional). | `None` |
| `strategy` | `FinetuneStrategy \| None` | `FinetuneStrategy` instance (HeadOnly, LoRA, Partial, Full). | `None` |
| `epochs` | `int` | Number of training epochs. | `10` |
| `gradient_checkpointing` | `bool` | Trade compute for VRAM (~40-60% savings). | `False` |
| `tracker` | | Optional `BaseTracker` (e.g. `ExperimentRegistry`) for logging. | `None` |
| `checkpoint_dir` | `str \| None` | Directory for periodic and best-model checkpoints. | `None` |
| `save_every` | `int` | Save a checkpoint every N epochs (0 = only best). | `0` |
| `resume_from` | `str \| None` | Path to checkpoint directory to resume training from. | `None` |
Returns:
| Type | Description |
|---|---|
| `list[dict]` | List of per-epoch metric dicts. |
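Because the return value is one metric dict per epoch, the history can be post-processed directly, for example to find the best epoch. The `"epoch"` and `"val_loss"` key names below are illustrative assumptions, not guaranteed by the API:

```python
# history as returned by fit(): one dict per epoch.
# Key names are assumed for illustration only.
history = [
    {"epoch": 1, "train_loss": 2.1, "val_loss": 1.9},
    {"epoch": 2, "train_loss": 1.4, "val_loss": 1.2},
    {"epoch": 3, "train_loss": 1.1, "val_loss": 1.3},
]

# Pick the epoch with the lowest validation loss.
best = min(history, key=lambda m: m["val_loss"])
```

Note that train loss keeps falling in epoch 3 while validation loss rises, so the best checkpoint here is epoch 2; this is exactly the signal `checkpoint_dir` uses for best-model saving.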
from_custom (classmethod) ¶
from_custom(adapter: BaseAdapter, device: str = 'cuda', head: str | None = None, head_config: dict | None = None) -> MolfunStructureModel
Create a MolfunStructureModel from a custom adapter (e.g. BuiltModel).
This bypasses the ADAPTER_REGISTRY and directly uses the provided adapter, enabling custom architectures built with ModelBuilder or hand-crafted nn.Modules that implement BaseAdapter.
Usage::

    from molfun.modules.builder import ModelBuilder

    built = ModelBuilder(
        embedder="input", block="pairformer", structure_module="ipa",
    ).build()
    model = MolfunStructureModel.from_custom(
        built, head="affinity", head_config={"single_dim": 256},
    )
swap ¶
Replace an internal submodule of the adapter's model.
Uses ModuleSwapper under the hood. The target_path is relative to the adapter's internal model (e.g. "structure_module", "evoformer.blocks.0").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target_path` | `str` | Dotted path to the submodule. | required |
| `new_module` | `Module` | Replacement module. | required |
| `transfer_weights` | `bool` | Copy matching weights from old module. | `False` |

Returns:

| Type | Description |
|---|---|
| `Module` | The old (replaced) module. |
Usage::

    from molfun.modules.structure_module import DiffusionStructureModule

    model.swap("structure_module", DiffusionStructureModule(...))
swap_all ¶
Swap all submodules matching a regex pattern.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pattern` | `str` | Regex pattern for module names. | required |
| `factory` | | Callable producing a replacement module for each match. | required |
| `transfer_weights` | `bool` | Copy matching weights from old modules. | `False` |

Returns:

| Type | Description |
|---|---|
| `int` | Number of modules swapped. |
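The regex selection behind `swap_all` can be sketched with plain `re` over dotted module names. This is an illustration of the matching step only (hypothetical module names, not Molfun's internal code):

```python
import re

# Hypothetical dotted module names, like those reported by
# discover_modules().
module_names = [
    "evoformer.blocks.0.attention",
    "evoformer.blocks.1.attention",
    "evoformer.blocks.1.transition",
    "structure_module",
]

def select(pattern: str, names: list[str]) -> list[str]:
    # re.search matches the pattern anywhere in the dotted path.
    return [n for n in names if re.search(pattern, n)]

matched = select(r"blocks\.\d+\.attention", module_names)
# swap_all would call the factory once per matched module and
# return the count, here 2.
```

Anchoring the pattern (e.g. `r"\.attention$"`) is a good habit, since an unanchored pattern also matches any submodule nested under a matching parent.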
discover_modules ¶
List swappable modules inside the model.
push_to_hub ¶
push_to_hub(repo_id: str, token: str | None = None, private: bool = False, metrics: dict | None = None, dataset_name: str | None = None, commit_message: str = 'Upload Molfun model') -> str
Push model checkpoint + auto-generated model card to HF Hub.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `repo_id` | `str` | Hub repo (e.g. `"user/my-affinity-model"`). | required |
| `token` | `str \| None` | HF API token (or set `HF_TOKEN` env var). | `None` |
| `private` | `bool` | Whether the repo should be private. | `False` |
| `metrics` | `dict \| None` | Evaluation metrics to include in the model card. | `None` |
| `dataset_name` | `str \| None` | Training dataset name for the card. | `None` |
| `commit_message` | `str` | Git commit message for the upload. | `'Upload Molfun model'` |

Returns:

| Type | Description |
|---|---|
| `str` | URL of the uploaded repo. |
Usage::

    model.push_to_hub("rubencr/kinase-affinity-lora", metrics={"mae": 0.42})
from_hub (classmethod) ¶
from_hub(repo_id: str, token: str | None = None, revision: str | None = None, device: str = 'cpu', head: str | None = None, head_config: dict | None = None) -> MolfunStructureModel
Download and load a model from Hugging Face Hub.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `repo_id` | `str` | Hub repo (e.g. `"user/my-affinity-model"`). | required |
| `token` | `str \| None` | HF API token. | `None` |
| `revision` | `str \| None` | Git revision (branch, tag, commit hash). | `None` |
| `device` | `str` | Target device. | `'cpu'` |
| `head` | `str \| None` | Task head to attach (overrides saved head). | `None` |
| `head_config` | `dict \| None` | Head kwargs (overrides saved config). | `None` |

Returns:

| Type | Description |
|---|---|
| `MolfunStructureModel` | Loaded `MolfunStructureModel`. |
Usage::

    model = MolfunStructureModel.from_hub("rubencr/kinase-affinity-lora")
    output = model.predict(batch)
export_onnx ¶
export_onnx(path: str, seq_len: int = 256, opset_version: int = 17, simplify: bool = False, device: str = 'cpu') -> Path
Export model to ONNX format for optimized inference.
Merge LoRA weights first if applicable::
    model.merge()
    model.export_onnx("model.onnx")
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | Output `.onnx` file path. | required |
| `seq_len` | `int` | Dummy sequence length for tracing. | `256` |
| `opset_version` | `int` | ONNX opset version. | `17` |
| `simplify` | `bool` | Run onnx-simplifier after export. | `False` |
| `device` | `str` | Device for tracing (`"cpu"` recommended). | `'cpu'` |

Returns:

| Type | Description |
|---|---|
| `Path` | Path to the exported file. |
export_torchscript ¶
export_torchscript(path: str, seq_len: int = 256, mode: str = 'trace', optimize: bool = True, device: str = 'cpu') -> Path
Export model to TorchScript for deployment without Python.
Merge LoRA weights first if applicable::
    model.merge()
    model.export_torchscript("model.pt")
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | Output `.pt` file path. | required |
| `seq_len` | `int` | Dummy sequence length for tracing. | `256` |
| `mode` | `str` | `"trace"` (default) or `"script"`. | `'trace'` |
| `optimize` | `bool` | Apply inference optimizations. | `True` |
| `device` | `str` | Device for tracing. | `'cpu'` |

Returns:

| Type | Description |
|---|---|
| `Path` | Path to the exported file. |
example_dataset (staticmethod) ¶
Fetch a small example dataset for quick experimentation.
Downloads PDB structures to ~/.molfun/examples/<name>/
and returns a list of file paths.
Available datasets:

- `"globins-small"`: 20 globin structures (small, fast)
- `"kinases-small"`: 30 human kinases
- `"gpcr-small"`: 20 GPCR structures
- `"mixed-tiny"`: 10 mixed structures (fastest)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Example dataset name. | `'globins-small'` |
| `cache_dir` | `str \| None` | Override cache directory. | `None` |

Returns:

| Type | Description |
|---|---|
| `list[str]` | List of paths to downloaded PDB/mmCIF files. |
Usage::

    paths = MolfunStructureModel.example_dataset("globins-small")
    print(f"Downloaded {len(paths)} structures")
available_pretrained (staticmethod) ¶
List available pretrained model names for from_pretrained().
Class Methods¶
from_pretrained¶
Load a model from a named pretrained checkpoint.
```python
model = MolfunStructureModel.from_pretrained(
    name="openfold_v2",
    device="cuda",
    dtype=torch.float16,
)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `name` | `str` | required | Pretrained model name (see `available_pretrained()`) |
| `device` | `str \| torch.device` | `"cpu"` | Target device |
| `dtype` | `torch.dtype` | `torch.float32` | Model precision |
Returns: MolfunStructureModel
from_hub¶
Download and load a model from the Molfun Hub.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `repo_id` | `str` | required | Hub repository identifier |
| `revision` | `str \| None` | `None` | Specific revision / tag |
| `device` | `str \| torch.device` | `"cpu"` | Target device |
Returns: MolfunStructureModel
from_custom¶
Build a model from modular components via ModelBuilder.
```python
model = MolfunStructureModel.from_custom(
    embedder="input",
    block="pairformer",
    n_blocks=48,
    structure_module="ipa",
)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `embedder` | `str` | required | Registered embedder name |
| `block` | `str` | required | Registered block name |
| `n_blocks` | `int` | required | Number of trunk blocks |
| `structure_module` | `str` | required | Registered structure module name |
| `**configs` | `dict` | `{}` | Extra configuration passed to each component |
Returns: MolfunStructureModel
Instance Methods¶
predict¶
Run inference on one or more sequences.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `sequence` | `str \| list[str]` | required | Amino acid sequence(s) |
| `num_recycles` | `int` | `3` | Number of recycling iterations |
| `msa` | `str \| None` | `None` | Path to MSA file (A3M) |
Returns: TrunkOutput with .positions, .plddt, .pae attributes.
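The per-residue confidence in `.plddt` is commonly used to mask out low-confidence regions. A sketch of that filtering step, under the assumption that `.plddt` is a per-residue sequence of scores on the usual 0-100 pLDDT scale:

```python
# Illustrative per-residue pLDDT values for an 8-residue
# prediction (a stand-in for output.plddt; the attribute's
# exact container type is an assumption).
plddt = [92.1, 88.4, 35.0, 41.2, 77.8, 90.3, 28.9, 95.0]

# Keep indices of residues predicted with high confidence
# (>= 70 is the conventional "confident" pLDDT threshold).
confident = [i for i, score in enumerate(plddt) if score >= 70]
```

The same index list can then be used to slice `.positions` or `.pae` so downstream analysis ignores poorly predicted loops.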
fit¶
Fine-tune the model using the selected strategy.
```python
model.fit(
    train_dataset=train_ds,
    val_dataset=val_ds,
    epochs=10,
    strategy="lora",
    lr=1e-4,
    batch_size=2,
    tracker="wandb",
)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `train_dataset` | `Dataset` | required | Training dataset |
| `val_dataset` | `Dataset \| None` | `None` | Validation dataset |
| `epochs` | `int` | `10` | Number of training epochs |
| `strategy` | `str` | `"full"` | Fine-tuning strategy: `"full"`, `"head_only"`, `"lora"`, `"partial"` |
| `lr` | `float` | `1e-4` | Learning rate |
| `batch_size` | `int` | `1` | Batch size |
| `tracker` | `str \| BaseTracker \| None` | `None` | Experiment tracker |
| `**kwargs` | `dict` | `{}` | Additional strategy-specific arguments |
Returns: Training metrics dict.
forward¶
Low-level forward pass (used internally by predict and fit).
| Parameter | Type | Description |
|---|---|---|
| `batch` | `Batch \| dict` | Feature dictionary from a DataLoader |
Returns: TrunkOutput
save / load¶
Persist and restore model checkpoints.
merge / unmerge¶
Merge or unmerge LoRA adapters into the base weights.
swap / swap_all¶
Replace individual or all modules of a given type at runtime.
```python
# Swap a single attention module
model.swap("attention", "flash")

# Swap all blocks
model.swap_all("block", "pairformer")
```
| Parameter | Type | Description |
|---|---|---|
| `module_type` | `str` | Module category: `"attention"`, `"block"`, `"embedder"`, `"structure_module"` |
| `name` | `str` | Name of the registered replacement module |
discover_modules¶
List all swappable modules currently in the model.
```python
modules = model.discover_modules()
# {"attention": ["layer_0", "layer_1", ...], "block": [...], ...}
```
Returns: dict[str, list[str]]
push_to_hub¶
Upload the model to the Molfun Hub.
export_onnx / export_torchscript¶
Export the model for deployment.
example_dataset¶
Fetch a small example dataset for quick testing.
Static / Class Methods¶
available_models¶
available_heads¶
available_pretrained¶
summary¶
Print a summary of model architecture and parameter counts.