MolfunStructureModel

The central facade for all Molfun operations: loading pretrained models, running predictions, fine-tuning, module swapping, export, and Hub integration.

Quick Start

from molfun import MolfunStructureModel

# Load a pretrained model
model = MolfunStructureModel.from_pretrained("openfold_v2")

# Predict a structure
output = model.predict("MKFLILLFNILCLFPVLAADNH...")

# Fine-tune on a custom dataset
model.fit(
    train_dataset=train_ds,
    val_dataset=val_ds,
    epochs=10,
    strategy="lora",
)

# Save and push to hub
model.save("./my_model")
model.push_to_hub("myorg/finetuned-openfold")

Class Reference

MolfunStructureModel

Unified API for protein structure models.

Wraps any registered adapter (OpenFold, ESMFold, ...) with a common interface for inference, fine-tuning, task heads, and checkpointing.

The model itself is agnostic to the training strategy. Strategies (HeadOnly, LoRA, Partial, Full) are passed to fit() and handle all freezing, param groups, schedulers, EMA, etc.
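The division of labor can be sketched in plain Python. This is a toy stand-in, not the real FinetuneStrategy interface: `Param` and `ToyHeadOnly` are illustrative names only.

```python
# Toy sketch of the model/strategy split: the model exposes parameters,
# the strategy decides which of them train. Illustrative only -- the real
# FinetuneStrategy objects (HeadOnly, LoRA, Partial, Full) are passed to fit().

class Param:
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

class ToyHeadOnly:
    """Freeze everything except parameters under 'head.'."""
    def prepare(self, params):
        for p in params:
            p.requires_grad = p.name.startswith("head.")
        return [p for p in params if p.requires_grad]

params = [Param("trunk.blocks.0.w"), Param("head.out.w")]
trainable = ToyHeadOnly().prepare(params)
```

The model never branches on the strategy; it just hands over its parameters and trains whatever the strategy left unfrozen.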

__init__

__init__(name: str, model: Module | None = None, config: object | None = None, weights: str | None = None, device: str = 'cuda', head: str | None = None, head_config: dict | None = None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Model backend ("openfold", "esmfold", ...). | required |
| model | Module \| None | Pre-built nn.Module. If None, built from config. | None |
| config | object \| None | Backend-specific config object. | None |
| weights | str \| None | Path to model checkpoint. | None |
| device | str | Target device. | 'cuda' |
| head | str \| None | Task head name ("affinity"). | None |
| head_config | dict \| None | Head kwargs (single_dim, hidden_dim, ...). | None |

from_pretrained classmethod

from_pretrained(name: str = 'openfold', device: str = 'cpu', head: str | None = None, head_config: dict | None = None, cache_dir: str | None = None, force_download: bool = False) -> MolfunStructureModel

Load a pretrained model with automatic weight download.

Downloads weights to ~/.molfun/weights/<name>/ on first call, then loads from cache on subsequent calls.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Pretrained model name. See available_pretrained(). | 'openfold' |
| device | str | Target device ("cpu" or "cuda"). | 'cpu' |
| head | str \| None | Optional task head ("affinity", "structure"). | None |
| head_config | dict \| None | Head kwargs. | None |
| cache_dir | str \| None | Override weight cache directory. | None |
| force_download | bool | Re-download even if cached. | False |

Returns:

| Type | Description |
| --- | --- |
| MolfunStructureModel | Ready-to-use MolfunStructureModel. |

Usage::

model = MolfunStructureModel.from_pretrained("openfold")
output = model.predict("MKWVTFISLLLLFSSAYS")

predict

predict(batch_or_sequence, **kwargs) -> TrunkOutput

Run inference (no grad, eval mode).

Accepts either a feature dict (batch) or a raw amino acid sequence string. When a string is passed, it is automatically featurized.
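The dispatch boils down to a type check. A toy illustration of that input handling (`AA_INDEX` and `featurize` here are stand-ins, not Molfun's real featurizer, which builds model-specific tensors):

```python
# Toy version of predict()'s input handling: strings are featurized,
# feature dicts pass through untouched.
AA_INDEX = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}

def featurize(sequence: str) -> dict:
    return {"aatype": [AA_INDEX[aa] for aa in sequence], "length": len(sequence)}

def as_batch(batch_or_sequence):
    if isinstance(batch_or_sequence, str):
        return featurize(batch_or_sequence)
    return batch_or_sequence  # already a feature dict
```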

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_or_sequence | dict \| str | Feature dict or amino acid sequence string (e.g. "MKWVTFISLLLLFSSAYS"). | required |

Returns:

| Type | Description |
| --- | --- |
| TrunkOutput | TrunkOutput with single_repr, pair_repr, structure_coords, and confidence (pLDDT). |

Usage::

# From sequence string
output = model.predict("MKWVTFISLLLLFSSAYS")

# From pre-built feature dict
output = model.predict(batch)

forward

forward(batch: dict, mask: Tensor | None = None) -> dict

Full forward: adapter → head.

Returns dict with "trunk_output" and optionally "preds". For StructureLossHead, "preds" is the scalar structure loss.

fit

fit(train_loader: DataLoader, val_loader: DataLoader | None = None, strategy: FinetuneStrategy | None = None, epochs: int = 10, gradient_checkpointing: bool = False, tracker=None, checkpoint_dir: str | None = None, save_every: int = 0, resume_from: str | None = None) -> list[dict]

Fine-tune the model using the given strategy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_loader | DataLoader | Training data. | required |
| val_loader | DataLoader \| None | Validation data (optional). | None |
| strategy | FinetuneStrategy \| None | FinetuneStrategy instance (HeadOnly, LoRA, Partial, Full). | None |
| epochs | int | Number of training epochs. | 10 |
| gradient_checkpointing | bool | Trade compute for VRAM (~40-60% savings). | False |
| tracker | BaseTracker \| None | Optional BaseTracker (e.g. ExperimentRegistry) for logging. | None |
| checkpoint_dir | str \| None | Directory for periodic and best-model checkpoints. | None |
| save_every | int | Save a checkpoint every N epochs (0 = only best). | 0 |
| resume_from | str \| None | Path to checkpoint directory to resume training from. | None |

Returns:

| Type | Description |
| --- | --- |
| list[dict] | List of per-epoch metric dicts. |
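The save_every semantics can be pinned down with a small helper (hypothetical, mirroring the documented behaviour rather than Molfun's actual checkpointing code):

```python
def should_save_periodic(epoch: int, save_every: int) -> bool:
    # save_every=0 disables periodic saves: only the best checkpoint is kept.
    # Otherwise a checkpoint is written every save_every epochs.
    return save_every > 0 and epoch % save_every == 0
```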

save

save(path: str) -> None

Save PEFT adapters + head weights.

load

load(path: str) -> None

Load PEFT adapters + head weights.

merge

merge() -> None

Merge PEFT weights into base model for production export.

from_custom classmethod

from_custom(adapter: BaseAdapter, device: str = 'cuda', head: str | None = None, head_config: dict | None = None) -> MolfunStructureModel

Create a MolfunStructureModel from a custom adapter (e.g. BuiltModel).

This bypasses the ADAPTER_REGISTRY and directly uses the provided adapter, enabling custom architectures built with ModelBuilder or hand-crafted nn.Modules that implement BaseAdapter.

Usage::

from molfun.modules.builder import ModelBuilder
built = ModelBuilder(
    embedder="input", block="pairformer", structure_module="ipa",
).build()
model = MolfunStructureModel.from_custom(
    built, head="affinity", head_config={"single_dim": 256},
)

swap

swap(target_path: str, new_module: Module, transfer_weights: bool = False) -> nn.Module

Replace an internal submodule of the adapter's model.

Uses ModuleSwapper under the hood. The target_path is relative to the adapter's internal model (e.g. "structure_module", "evoformer.blocks.0").

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_path | str | Dotted path to the submodule. | required |
| new_module | Module | Replacement module. | required |
| transfer_weights | bool | Copy matching weights from the old module. | False |

Returns:

| Type | Description |
| --- | --- |
| Module | The old (replaced) module. |

Usage::

from molfun.modules.structure_module import DiffusionStructureModule
model.swap("structure_module", DiffusionStructureModule(...))
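Dotted-path resolution of this kind typically walks attributes, treating numeric segments as indices into block lists. A stdlib-only sketch (not ModuleSwapper's actual code):

```python
from types import SimpleNamespace

def resolve(root, dotted_path: str):
    # Walk "evoformer.blocks.0" by attribute access, using integer
    # indexing for numeric segments (e.g. entries of a block list).
    obj = root
    for part in dotted_path.split("."):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

# Toy nested "model" standing in for the adapter's internal module tree.
model_tree = SimpleNamespace(evoformer=SimpleNamespace(blocks=["block0", "block1"]))
```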

swap_all

swap_all(pattern: str, factory, transfer_weights: bool = False) -> int

Swap all submodules matching a regex pattern.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pattern | str | Regex pattern for module names. | required |
| factory | callable | factory(name, old_module) → new_module. | required |
| transfer_weights | bool | Copy matching weights from old modules. | False |

Returns:

| Type | Description |
| --- | --- |
| int | Number of modules swapped. |
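The selection step can be sketched with the stdlib re module. This assumes full-name matching, which may differ from ModuleSwapper's actual matching rule; `matching_names` is illustrative only:

```python
import re

def matching_names(module_names, pattern: str):
    # Select every module whose dotted name matches the pattern.
    # swap_all would then call factory(name, old_module) for each hit
    # and return the number of replacements.
    return [name for name in module_names if re.fullmatch(pattern, name)]

names = ["evoformer.blocks.0", "evoformer.blocks.1", "structure_module"]
```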

discover_modules

discover_modules(pattern: str | None = None) -> list[tuple[str, nn.Module]]

List swappable modules inside the model.

push_to_hub

push_to_hub(repo_id: str, token: str | None = None, private: bool = False, metrics: dict | None = None, dataset_name: str | None = None, commit_message: str = 'Upload Molfun model') -> str

Push model checkpoint + auto-generated model card to HF Hub.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| repo_id | str | Hub repo (e.g. "user/my-affinity-model"). | required |
| token | str \| None | HF API token (or set the HF_TOKEN env var). | None |
| private | bool | Whether the repo should be private. | False |
| metrics | dict \| None | Evaluation metrics to include in the model card. | None |
| dataset_name | str \| None | Training dataset name for the card. | None |
| commit_message | str | Git commit message for the upload. | 'Upload Molfun model' |

Returns:

| Type | Description |
| --- | --- |
| str | URL of the uploaded repo. |

Usage::

model.push_to_hub("rubencr/kinase-affinity-lora", metrics={"mae": 0.42})

from_hub classmethod

from_hub(repo_id: str, token: str | None = None, revision: str | None = None, device: str = 'cpu', head: str | None = None, head_config: dict | None = None) -> MolfunStructureModel

Download and load a model from Hugging Face Hub.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| repo_id | str | Hub repo (e.g. "user/my-affinity-model"). | required |
| token | str \| None | HF API token. | None |
| revision | str \| None | Git revision (branch, tag, commit hash). | None |
| device | str | Target device. | 'cpu' |
| head | str \| None | Task head to attach (overrides saved head). | None |
| head_config | dict \| None | Head kwargs (overrides saved config). | None |

Returns:

| Type | Description |
| --- | --- |
| MolfunStructureModel | Loaded MolfunStructureModel. |

Usage::

model = MolfunStructureModel.from_hub("rubencr/kinase-affinity-lora")
output = model.predict(batch)

export_onnx

export_onnx(path: str, seq_len: int = 256, opset_version: int = 17, simplify: bool = False, device: str = 'cpu') -> Path

Export model to ONNX format for optimized inference.

Merge LoRA weights first if applicable::

model.merge()
model.export_onnx("model.onnx")

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | Output .onnx file path. | required |
| seq_len | int | Dummy sequence length for tracing. | 256 |
| opset_version | int | ONNX opset version. | 17 |
| simplify | bool | Run onnx-simplifier after export. | False |
| device | str | Device for tracing ("cpu" recommended). | 'cpu' |

Returns:

| Type | Description |
| --- | --- |
| Path | Path to the exported file. |

export_torchscript

export_torchscript(path: str, seq_len: int = 256, mode: str = 'trace', optimize: bool = True, device: str = 'cpu') -> Path

Export model to TorchScript for deployment without Python.

Merge LoRA weights first if applicable::

model.merge()
model.export_torchscript("model.pt")

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | Output .pt file path. | required |
| seq_len | int | Dummy sequence length for tracing. | 256 |
| mode | str | "trace" (default) or "script". | 'trace' |
| optimize | bool | Apply inference optimizations. | True |
| device | str | Device for tracing. | 'cpu' |

Returns:

| Type | Description |
| --- | --- |
| Path | Path to the exported file. |

example_dataset staticmethod

example_dataset(name: str = 'globins-small', cache_dir: str | None = None) -> list[str]

Fetch a small example dataset for quick experimentation.

Downloads PDB structures to ~/.molfun/examples/<name>/ and returns a list of file paths.

Available datasets:

  • "globins-small" — 20 globin structures (small, fast)
  • "kinases-small" — 30 human kinases
  • "gpcr-small" — 20 GPCR structures
  • "mixed-tiny" — 10 mixed structures (fastest)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Example dataset name. | 'globins-small' |
| cache_dir | str \| None | Override cache directory. | None |

Returns:

| Type | Description |
| --- | --- |
| list[str] | List of paths to downloaded PDB/mmCIF files. |

Usage::

paths = MolfunStructureModel.example_dataset("globins-small")
print(f"Downloaded {len(paths)} structures")

available_pretrained staticmethod

available_pretrained() -> list[str]

List available pretrained model names for from_pretrained().

Class Methods

from_pretrained

Load a model from a named pretrained checkpoint.

model = MolfunStructureModel.from_pretrained(
    name="openfold_v2",
    device="cuda",
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Pretrained model name (see available_pretrained()) |
| device | str | "cpu" | Target device |
| head | str \| None | None | Optional task head |
| head_config | dict \| None | None | Head kwargs |

Returns: MolfunStructureModel


from_hub

Download and load a model from the Molfun Hub.

model = MolfunStructureModel.from_hub("myorg/finetuned-openfold")

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| repo_id | str | required | Hub repository identifier |
| revision | str \| None | None | Specific revision / tag |
| device | str | "cpu" | Target device |

Returns: MolfunStructureModel


from_custom

Wrap a custom adapter, e.g. one assembled from modular components via ModelBuilder.

from molfun.modules.builder import ModelBuilder

built = ModelBuilder(
    embedder="input",
    block="pairformer",
    structure_module="ipa",
).build()
model = MolfunStructureModel.from_custom(built)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| adapter | BaseAdapter | required | Custom adapter (e.g. a BuiltModel) |
| device | str | "cuda" | Target device |
| head | str \| None | None | Task head name |
| head_config | dict \| None | None | Head kwargs |

Returns: MolfunStructureModel


Instance Methods

predict

Run inference on one or more sequences.

output = model.predict(
    "MKFLILLFNILCLFPVLAADNH",
    num_recycles=3,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| sequence | str \| list[str] | required | Amino acid sequence(s) |
| num_recycles | int | 3 | Number of recycling iterations |
| msa | str \| None | None | Path to MSA file (A3M) |

Returns: TrunkOutput with .positions, .plddt, .pae attributes.


fit

Fine-tune the model using the selected strategy.

model.fit(
    train_dataset=train_ds,
    val_dataset=val_ds,
    epochs=10,
    strategy="lora",
    lr=1e-4,
    batch_size=2,
    tracker="wandb",
)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| train_dataset | Dataset | required | Training dataset |
| val_dataset | Dataset \| None | None | Validation dataset |
| epochs | int | 10 | Number of training epochs |
| strategy | str | "full" | Fine-tuning strategy: "full", "head_only", "lora", "partial" |
| lr | float | 1e-4 | Learning rate |
| batch_size | int | 1 | Batch size |
| tracker | str \| BaseTracker \| None | None | Experiment tracker |
| **kwargs | dict | {} | Additional strategy-specific arguments |

Returns: list[dict] of per-epoch training metrics.


forward

Low-level forward pass (used internally by predict and fit).

output = model.forward(batch)

| Parameter | Type | Description |
| --- | --- | --- |
| batch | Batch \| dict | Feature dictionary from a DataLoader |

Returns: dict with "trunk_output" and, when a head is attached, "preds".


save / load

Persist and restore model checkpoints.

model.save("./checkpoints/epoch_10")
model.load("./checkpoints/epoch_10")

merge / unmerge

Merge or unmerge LoRA adapters into the base weights.

model.merge()    # fuse adapters into base weights
model.unmerge()  # separate them back out
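Conceptually, merging folds the low-rank update into the base weight (W ← W + B·A) and unmerging subtracts it back out. A toy pure-Python sketch with the scale factor omitted (real PEFT implementations also apply an alpha/r scale):

```python
def matmul(a, b):
    # Naive matrix product for small illustration matrices.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def merge_lora(W, A, B, sign=+1):
    # sign=+1 merges (W + B @ A); sign=-1 unmerges (W - B @ A).
    BA = matmul(B, A)
    return [[W[i][j] + sign * BA[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]
```

Because the update is additive, merge followed by unmerge recovers the original weights exactly (up to floating-point error).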

swap / swap_all

Replace a single submodule by dotted path, or all submodules matching a regex, at runtime.

# Swap a single submodule
model.swap("structure_module", DiffusionStructureModule(...))

# Swap every evoformer block via a factory
model.swap_all(r"evoformer\.blocks\.\d+", factory)

| Parameter | Type | Description |
| --- | --- | --- |
| target_path / pattern | str | Dotted path (swap) or regex over module names (swap_all) |
| new_module / factory | Module / callable | Replacement module, or factory(name, old_module) → new_module |

discover_modules

List all swappable modules currently in the model.

modules = model.discover_modules()
# [("evoformer.blocks.0", <module>), ...]

Returns: list[tuple[str, nn.Module]]


push_to_hub

Upload the model to the Molfun Hub.

model.push_to_hub("myorg/finetuned-openfold", private=True)

export_onnx / export_torchscript

Export the model for deployment.

model.export_onnx("model.onnx", opset_version=17)
model.export_torchscript("model.pt")

example_dataset

Fetch a small example dataset of PDB structures for quick experimentation.

paths = MolfunStructureModel.example_dataset("globins-small")

Static / Class Methods

available_models

MolfunStructureModel.available_models()
# ["openfold_v1", "openfold_v2", ...]

available_heads

MolfunStructureModel.available_heads()
# ["structure", "affinity", "property", ...]

available_pretrained

MolfunStructureModel.available_pretrained()
# ["openfold_v2", "openfold_finetuned_casp15", ...]

summary

Print a summary of model architecture and parameter counts.

model.summary()