
ModuleSwapper

Hot-swap modular components in a live model without rebuilding from scratch. Useful for architecture search, ablation studies, and runtime optimization.

Quick Start

from molfun.modules.swapper import ModuleSwapper
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# Discover swappable modules
swappable = ModuleSwapper.discover(model)
print(swappable)
# {"attention": ["trunk.layers.0.attn", ...], "block": ["trunk.layers.0", ...], ...}

# Swap a single module
ModuleSwapper.swap(model, module_type="attention", name="flash")

# Swap all modules of a type
ModuleSwapper.swap_all(model, module_type="attention", name="linear")

# Swap by type (swap all instances of a specific class)
ModuleSwapper.swap_by_type(
    model,
    old_type="StandardAttention",
    new_type="flash",
)

Class Reference

ModuleSwapper

Replace submodules inside an nn.Module at runtime.

All methods are static — no instance state needed.

swap staticmethod

swap(model: Module, target_path: str, new_module: Module, transfer_weights: bool = False) -> nn.Module

Replace a single submodule identified by dotted path.

Parameters:

Parameter Type Default Description
model Module required The parent model.
target_path str required Dotted path (e.g. "evoformer.blocks.0.msa_att").
new_module Module required The replacement module.
transfer_weights bool False If True, copy matching weights from old → new.

Returns:

Type Description
Module The old (replaced) module.

Raises:

Type Description
KeyError If target_path does not exist in model.
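The following minimal sketch shows what a dotted-path swap involves in plain PyTorch, following the signature above. It is not molfun's implementation. `swap_by_path` is a hypothetical name, and the sketch assumes that "matching weights" means state-dict entries with the same key and the same tensor shape.

```python
import torch
import torch.nn as nn


def swap_by_path(model: nn.Module, target_path: str, new_module: nn.Module,
                 transfer_weights: bool = False) -> nn.Module:
    """Hypothetical sketch: replace the submodule at target_path and return the old one."""
    if not target_path:
        raise KeyError("the root module cannot be swapped")
    try:
        # get_submodule walks the dotted path and raises AttributeError on a missing step
        old_module = model.get_submodule(target_path)
    except AttributeError as exc:
        raise KeyError(target_path) from exc

    if transfer_weights:
        # Assumption: "matching" means same state-dict key and same tensor shape
        new_state = new_module.state_dict()
        compatible = {
            key: value
            for key, value in old_module.state_dict().items()
            if key in new_state and new_state[key].shape == value.shape
        }
        new_module.load_state_dict(compatible, strict=False)

    # Re-attach the replacement on its parent under the same attribute name
    parent_path, _, child_name = target_path.rpartition(".")
    parent = model.get_submodule(parent_path)  # "" resolves to the model itself
    setattr(parent, child_name, new_module)
    return old_module


# Usage: replace the first Linear of a small stack, carrying its weights over
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
old = swap_by_path(model, "0", nn.Linear(4, 4), transfer_weights=True)
```

Assigning through the parent also works for nn.Sequential and nn.ModuleList, because their children are registered under string keys such as "0".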

swap_all staticmethod

swap_all(model: Module, pattern: str, factory: Callable[[str, Module], Module], transfer_weights: bool = False) -> int

Swap all submodules whose full name matches a regex pattern.

Parameters:

Parameter Type Default Description
model Module required The parent model.
pattern str required Regex pattern matched against the full dotted name (e.g. "msa_att" matches "evoformer.blocks.3.msa_att").
factory Callable[[str, Module], Module] required factory(name, old_module) → new_module.
transfer_weights bool False If True, copy matching weights.

Returns:

Type Description
int Number of modules swapped.
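The pattern-plus-factory contract can be sketched the same way. This version is hypothetical and omits weight transfer for brevity. It uses `re.search`, which is consistent with the example above where "msa_att" matches a longer dotted name. It also skips modules nested inside one it has already replaced.

```python
import re
from typing import Callable, List

import torch.nn as nn


def swap_all_sketch(model: nn.Module, pattern: str,
                    factory: Callable[[str, nn.Module], nn.Module]) -> int:
    """Hypothetical sketch: swap every submodule whose dotted name matches pattern."""
    regex = re.compile(pattern)
    # Collect names first; mutating the tree while iterating named_modules() is unsafe
    matches = [name for name, _ in model.named_modules() if name and regex.search(name)]

    swapped: List[str] = []
    for name in matches:
        # named_modules() is pre-order, so ancestors come first; skip their descendants
        if any(name.startswith(prefix + ".") for prefix in swapped):
            continue
        old_module = model.get_submodule(name)
        parent_path, _, child_name = name.rpartition(".")
        setattr(model.get_submodule(parent_path), child_name, factory(name, old_module))
        swapped.append(name)
    return len(swapped)


# Usage: replace the activation (index 1) inside each of two blocks
encoder = nn.Sequential(
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)
count = swap_all_sketch(encoder, r"^\d+\.1$", lambda name, old: nn.GELU())
```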

swap_by_type staticmethod

swap_by_type(model: Module, old_type: type, factory: Callable[[str, Module], Module], transfer_weights: bool = False) -> int

Swap all submodules of a given type.

Parameters:

Parameter Type Default Description
model Module required The parent model.
old_type type required Type to match (e.g. nn.MultiheadAttention).
factory Callable[[str, Module], Module] required factory(name, old_module) → new_module.
transfer_weights bool False If True, copy matching weights.

Returns:

Type Description
int Number of modules swapped.
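Here is a hypothetical sketch of type-based matching. It uses `isinstance`, so subclasses match too. The reference does not say whether molfun behaves the same way, so treat that as an assumption.

```python
from typing import Callable, List

import torch.nn as nn


def swap_by_type_sketch(model: nn.Module, old_type: type,
                        factory: Callable[[str, nn.Module], nn.Module]) -> int:
    """Hypothetical sketch: swap every submodule that is an instance of old_type."""
    # Collect first (pre-order), then mutate; the root module itself is never swapped
    targets = [name for name, module in model.named_modules()
               if name and isinstance(module, old_type)]
    swapped: List[str] = []
    for name in targets:
        if any(name.startswith(prefix + ".") for prefix in swapped):
            continue  # already gone: it lived inside a module we replaced
        parent_path, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_path)
        setattr(parent, child_name, factory(name, getattr(parent, child_name)))
        swapped.append(name)
    return len(swapped)


# Usage: turn every ReLU, including nested ones, into GELU
model = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)
count = swap_by_type_sketch(model, nn.ReLU, lambda name, old: nn.GELU())
```

Subclass matching matters in practice. nn.MultiheadAttention stores its output projection as a subclass of nn.Linear, so swapping nn.Linear with `isinstance` would also reach inside attention layers. Comparing `type(module) is old_type` restricts the match to the exact class.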

discover staticmethod

discover(model: Module, pattern: str | None = None, module_type: type | None = None) -> list[tuple[str, nn.Module]]

List modules in the model, optionally filtered by name pattern or type.

Useful for inspecting a model's structure before deciding what to swap.

Returns:

Type Description
list[tuple[str, Module]] List of (dotted_name, module) tuples.
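The filtering described above can be sketched in plain PyTorch. This sketch is hypothetical and makes two assumptions: when both filters are given, a module must pass both, and the root module is left out.

```python
import re
from typing import List, Optional, Tuple

import torch.nn as nn


def discover_sketch(model: nn.Module, pattern: Optional[str] = None,
                    module_type: Optional[type] = None) -> List[Tuple[str, nn.Module]]:
    """Hypothetical sketch: list (dotted_name, module) pairs that pass both filters."""
    regex = re.compile(pattern) if pattern is not None else None
    found = []
    for name, module in model.named_modules():
        if not name:
            continue  # assumption: the root module is not listed
        if regex is not None and not regex.search(name):
            continue
        if module_type is not None and not isinstance(module, module_type):
            continue
        found.append((name, module))
    return found


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print([name for name, _ in discover_sketch(model, module_type=nn.Linear)])
# ['0', '2']
```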

summary staticmethod

summary(model: Module, pattern: str | None = None) -> str

Human-readable summary of swappable modules.

Returns a formatted string listing module paths, types, and param counts.
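The exact layout of the summary is not documented. Below is one hypothetical way to produce a path/type/parameter-count listing. Note that `parameters()` recurses, so a container's count includes its children.

```python
import re
from typing import Optional

import torch.nn as nn


def summary_sketch(model: nn.Module, pattern: Optional[str] = None) -> str:
    """Hypothetical sketch: one line per submodule with path, class name, and param count."""
    regex = re.compile(pattern) if pattern is not None else None
    rows = []
    for name, module in model.named_modules():
        if not name or (regex is not None and not regex.search(name)):
            continue
        # parameters() recurses, so a container's count includes its children
        n_params = sum(p.numel() for p in module.parameters())
        rows.append((name, type(module).__name__, n_params))
    width = max([len("path")] + [len(name) for name, _, _ in rows])
    lines = [f"{'path':<{width}}  {'type':<12}  params"]
    lines += [f"{name:<{width}}  {kind:<12}  {count:,}" for name, kind, count in rows]
    return "\n".join(lines)


print(summary_sketch(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))))
```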

swap

Replace a specific module instance with a registered alternative.

ModuleSwapper.swap(
    model,
    module_type="attention",
    name="flash",
    module_path="trunk.layers.0.attn",  # optional: target a specific submodule
)
Parameter Type Default Description
model nn.Module required Model to modify
module_type str required Category: "attention", "block", "embedder", "structure_module"
name str required Registered replacement name
module_path str | None None Dotted path to a specific submodule to replace
**kwargs dict {} Extra args passed to the new module constructor
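The registry behind name-based swapping is not documented on this page. One hypothetical shape for it maps a category and a name to a constructor and forwards `**kwargs` to that constructor. The `activation` category, the `REGISTRY` dict, `swap_registered`, and every registered name below are illustrative only, not molfun's API.

```python
from typing import Callable, Dict

import torch.nn as nn

# Hypothetical registry: category -> registered name -> module constructor
REGISTRY: Dict[str, Dict[str, Callable[..., nn.Module]]] = {
    "activation": {"gelu": nn.GELU, "leaky": nn.LeakyReLU},
}


def swap_registered(model: nn.Module, category: str, name: str,
                    module_path: str, **kwargs) -> nn.Module:
    """Hypothetical sketch: build REGISTRY[category][name](**kwargs) and put it at module_path."""
    try:
        constructor = REGISTRY[category][name]
    except KeyError as exc:
        raise KeyError(f"no {category!r} module registered as {name!r}") from exc
    new_module = constructor(**kwargs)  # extra kwargs go straight to the constructor
    parent_path, _, child_name = module_path.rpartition(".")
    parent = model.get_submodule(parent_path)
    old_module = getattr(parent, child_name)
    setattr(parent, child_name, new_module)
    return old_module


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
old = swap_registered(model, "activation", "leaky", "1", negative_slope=0.2)
```

With a registry, callers pass only a name and optional constructor arguments. The mapping from names to concrete classes stays in one place.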

swap_all

Replace all modules of a given type throughout the model.

ModuleSwapper.swap_all(model, module_type="attention", name="flash")
Parameter Type Default Description
model nn.Module required Model to modify
module_type str required Category to replace
name str required Registered replacement name
**kwargs dict {} Extra args for the new modules

swap_by_type

Replace all instances of a specific class (regardless of registry category).

ModuleSwapper.swap_by_type(
    model,
    old_type="StandardAttention",
    new_type="flash",
)
Parameter Type Default Description
model nn.Module required Model to modify
old_type str | type required Class name or class to replace
new_type str required Registered replacement name
**kwargs dict {} Extra args for the new modules

discover

List all swappable modules found in the model, grouped by type.

swappable = ModuleSwapper.discover(model)
# {
#     "attention": ["trunk.layers.0.attn", "trunk.layers.1.attn", ...],
#     "block": ["trunk.layers.0", "trunk.layers.1", ...],
#     "embedder": ["embedder"],
#     "structure_module": ["structure_module"],
# }
Parameter Type Description
model nn.Module Model to inspect

Returns: dict[str, list[str]] mapping module types to lists of dotted paths.


Example: Architecture Ablation

from molfun import MolfunStructureModel
from molfun.modules.swapper import ModuleSwapper

model = MolfunStructureModel.from_pretrained("openfold_v2")

results = {}
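for attn_type in ["standard", "flash", "linear", "gated"]:
    # Replace every attention module in place; the model keeps this variant until the next swap
    ModuleSwapper.swap_all(model, "attention", attn_type)
    output = model.predict("MKFLILLFNILCLFPVLAADNH...")
    # Score each variant by its mean pLDDT
    results[attn_type] = output.plddt.mean().item()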

print(results)
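The loop above swaps the same model in place, so each variant replaces the previous one. If every variant should instead start from the pretrained baseline, one option is to deep-copy the model per variant. The sketch below shows that pattern on a toy PyTorch network. The activation swap and the output mean are hypothetical stand-ins for an attention swap and a pLDDT score; they are not molfun calls.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
inputs = torch.randn(8, 16)

# Stand-in variants; in a real ablation these would be attention modules
variants = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

results = {}
with torch.no_grad():  # inference only, no autograd graph needed
    for label, activation_cls in variants.items():
        candidate = copy.deepcopy(base)  # fresh copy, so swaps never accumulate
        candidate[1] = activation_cls()  # nn.Sequential supports item assignment
        results[label] = candidate(inputs).mean().item()

print(results)
```

Deep-copying costs memory proportional to the model size. The in-place loop avoids that cost but makes each result depend on the order of the variants.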