# ModuleSwapper
Hot-swap modular components in a live model without rebuilding from scratch. Useful for architecture search, ablation studies, and runtime optimization.
## Quick Start

```python
from molfun.modules.swapper import ModuleSwapper
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# Discover swappable modules
swappable = ModuleSwapper.discover(model)
print(swappable)
# {"attention": ["layer_0.attn", ...], "block": ["layer_0", ...], ...}

# Swap a single module
ModuleSwapper.swap(model, module_type="attention", name="flash")

# Swap all modules of a type
ModuleSwapper.swap_all(model, module_type="attention", name="linear")

# Swap by type (swap all instances of a specific class)
ModuleSwapper.swap_by_type(
    model,
    old_type="StandardAttention",
    new_type="flash",
)
```
## Class Reference

### `ModuleSwapper`

Replace submodules inside an `nn.Module` at runtime.
All methods are static, so no instance is needed.
#### `swap` *(staticmethod)*

```python
swap(model: Module, target_path: str, new_module: Module, transfer_weights: bool = False) -> nn.Module
```

Replace a single submodule identified by dotted path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The parent model. | required |
| `target_path` | `str` | Dotted path (e.g. `"evoformer.blocks.0.msa_att"`). | required |
| `new_module` | `Module` | The replacement module. | required |
| `transfer_weights` | `bool` | If `True`, copy matching weights from old → new. | `False` |

Returns:

| Type | Description |
|---|---|
| `Module` | The old (replaced) module. |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If `target_path` does not exist in `model`. |
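A dotted `target_path` can be resolved with plain `getattr`/`setattr`. The snippet below is a minimal sketch of that idea, not molfun's implementation. The `swap_path` helper is hypothetical, and `SimpleNamespace` stands in for `nn.Module` so the example runs without PyTorch.

The same calls should carry over to a real `torch.nn.Module`:
- Child modules are reachable as attributes, including numbered `nn.ModuleList` or `nn.Sequential` entries such as `"0"`.
- Assigning a module attribute registers the new object as the child.

```python
from types import SimpleNamespace
from typing import Any


def swap_path(root: Any, target_path: str, new_module: Any) -> Any:
    """Replace the attribute at `target_path` and return the old value.

    Hypothetical helper illustrating dotted-path replacement. It raises
    KeyError (as documented for ModuleSwapper.swap) if a segment is missing.
    """
    *parent_parts, leaf = target_path.split(".")
    parent = root
    for part in parent_parts:
        if not hasattr(parent, part):
            raise KeyError(target_path)
        parent = getattr(parent, part)
    if not hasattr(parent, leaf):
        raise KeyError(target_path)
    old = getattr(parent, leaf)
    setattr(parent, leaf, new_module)  # on nn.Module this re-registers the child
    return old


# Stand-in tree mirroring the path "evoformer.blocks.0.msa_att"
block0 = SimpleNamespace(msa_att="standard-attn")
model = SimpleNamespace(
    evoformer=SimpleNamespace(blocks=SimpleNamespace(**{"0": block0}))
)

old = swap_path(model, "evoformer.blocks.0.msa_att", "flash-attn")
print(old)                                              # standard-attn
print(getattr(model.evoformer.blocks, "0").msa_att)     # flash-attn
```

Returning the replaced object lets the caller keep it, for example to restore it later or to copy weights from it.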
#### `swap_all` *(staticmethod)*

```python
swap_all(model: Module, pattern: str, factory: Callable[[str, Module], Module], transfer_weights: bool = False) -> int
```

Swap all submodules whose full dotted name matches a regex pattern.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The parent model. | required |
| `pattern` | `str` | Regex pattern matched against the full dotted name (e.g. `"msa_att"` matches `"evoformer.blocks.3.msa_att"`). | required |
| `factory` | `Callable[[str, Module], Module]` | Called with each matching submodule's dotted name and the module itself; returns the replacement module. | required |
| `transfer_weights` | `bool` | If `True`, copy matching weights. | `False` |

Returns:

| Type | Description |
|---|---|
| `int` | Number of modules swapped. |
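The `pattern` example (`"msa_att"` matching `"evoformer.blocks.3.msa_att"`) implies the regex is searched anywhere in the name rather than matched against the whole name. The sketch below assumes `re.search` semantics and uses hypothetical module names.

It shows two things:
- An unanchored pattern can also hit a child of the target or a similarly named sibling.
- Anchoring to a whole path component narrows the match.

```python
import re

# Hypothetical dotted names, in the style of nn.Module.named_modules()
names = [
    "evoformer.blocks.0.msa_att",
    "evoformer.blocks.0.msa_att.linear_q",
    "evoformer.blocks.1.msa_att",
    "evoformer.blocks.1.msa_att_gate",
    "evoformer.blocks.1.pair_att",
]


def select(pattern: str, names: list[str]) -> list[str]:
    """Names the pattern would hit, assuming re.search (match anywhere)."""
    return [n for n in names if re.search(pattern, n)]


# Unanchored: also hits a child (".linear_q") and a sibling ("_gate")
loose = select("msa_att", names)
# Anchored to a whole final path component: only the two msa_att modules
strict = select(r"(^|\.)msa_att$", names)
# Anchored and restricted to block 1
one_block = select(r"(^|\.)blocks\.1\.msa_att$", names)

print(len(loose), strict, one_block)
```

Swapping both a module and one of its own children in a single pass is rarely intended, which is why an anchored pattern is usually the safer choice.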
#### `swap_by_type` *(staticmethod)*

```python
swap_by_type(model: Module, old_type: type, factory: Callable[[str, Module], Module], transfer_weights: bool = False) -> int
```

Swap all submodules of a given type.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The parent model. | required |
| `old_type` | `type` | Type (class) to match. | required |
| `factory` | `Callable[[str, Module], Module]` | Called with each matching submodule's dotted name and the module itself; returns the replacement module. | required |
| `transfer_weights` | `bool` | If `True`, copy matching weights. | `False` |

Returns:

| Type | Description |
|---|---|
| `int` | Number of modules swapped. |
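`swap_all` and `swap_by_type` both take a `factory` typed `Callable[[str, Module], Module]`. Read from that type, the factory receives a dotted name and the module found there, and returns the replacement. This lets it size the new module from the old one.

The sketch below assumes that contract. It also assumes type matching uses `isinstance`, so subclasses match too; neither point is confirmed here. The classes and their `dim`/`heads` attributes are hypothetical, and `replace_by_type` is a hypothetical helper. All are written as plain Python so the example runs without PyTorch.

```python
from typing import Callable


class StandardAttention:                  # hypothetical stand-in for an attention module
    def __init__(self, dim: int, heads: int):
        self.dim, self.heads = dim, heads


class GatedAttention(StandardAttention):  # subclass: also matched by isinstance
    pass


class LinearAttention:                    # hypothetical replacement module
    def __init__(self, dim: int, heads: int):
        self.dim, self.heads = dim, heads


Factory = Callable[[str, object], object]


def linear_factory(name: str, old: StandardAttention) -> LinearAttention:
    # `name` is unused here but could vary the replacement per layer.
    # Reuse the old module's sizes so the replacement fits the same slot.
    return LinearAttention(dim=old.dim, heads=old.heads)


def replace_by_type(named: list[tuple[str, object]], old_type: type,
                    factory: Factory) -> tuple[list[tuple[str, object]], int]:
    """Return a new (name, module) list with matches replaced, plus the count."""
    out, count = [], 0
    for name, module in named:
        if isinstance(module, old_type):
            out.append((name, factory(name, module)))
            count += 1
        else:
            out.append((name, module))
    return out, count


named = [
    ("trunk.layers.0.attn", StandardAttention(dim=256, heads=8)),
    ("trunk.layers.1.attn", GatedAttention(dim=256, heads=4)),
    ("trunk.layers.1.norm", object()),    # unrelated module, left untouched
]
swapped, n = replace_by_type(named, StandardAttention, linear_factory)
print(n)  # 2
```

If the real matching compares exact types instead, `GatedAttention` would not be replaced. Check the returned count when a model mixes base classes and subclasses.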
#### `discover` *(staticmethod)*

```python
discover(model: Module, pattern: str | None = None, module_type: type | None = None) -> list[tuple[str, nn.Module]]
```

List modules in the model, optionally filtered by name pattern or type.
Useful for inspecting a model's structure before deciding what to swap.

Returns:

| Type | Description |
|---|---|
| `list[tuple[str, Module]]` | List of `(dotted_name, module)` tuples. |
#### `summary` *(staticmethod)*

Human-readable summary of swappable modules.
Returns a formatted string listing module paths, types, and parameter counts.
## swap

Replace a specific module instance with a registered alternative.

```python
ModuleSwapper.swap(
    model,
    module_type="attention",
    name="flash",
    module_path="trunk.layers.0.attn",  # optional: target a specific submodule
)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | Model to modify |
| `module_type` | `str` | required | Category: `"attention"`, `"block"`, `"embedder"`, `"structure_module"` |
| `name` | `str` | required | Registered replacement name |
| `module_path` | `str \| None` | `None` | Dotted path to a specific submodule to replace |
| `**kwargs` | `dict` | `{}` | Extra args passed to the new module constructor |
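The `name` and `**kwargs` parameters suggest a per-category registry of replacement classes, where the extra keyword arguments go to the chosen class's constructor. The snippet below is a minimal sketch of such a lookup, assuming a registry shaped as nested dicts.

`REGISTRY`, `register`, `build_replacement` and `FlashAttention` here are all hypothetical. They illustrate the idea and are not molfun's API.

```python
from typing import Any, Callable

# Hypothetical registry: category -> replacement name -> constructor
REGISTRY: dict[str, dict[str, Callable[..., Any]]] = {}


def register(category: str, name: str):
    """Decorator that records a class under (category, name)."""
    def wrap(cls):
        REGISTRY.setdefault(category, {})[name] = cls
        return cls
    return wrap


@register("attention", "flash")
class FlashAttention:  # hypothetical replacement class
    def __init__(self, dim: int = 128, dropout: float = 0.0):
        self.dim, self.dropout = dim, dropout


def build_replacement(module_type: str, name: str, **kwargs: Any) -> Any:
    """Look up a registered class and pass **kwargs to its constructor."""
    try:
        cls = REGISTRY[module_type][name]
    except KeyError:
        known = sorted(REGISTRY.get(module_type, {}))
        raise KeyError(f"no {module_type!r} named {name!r}; known: {known}") from None
    return cls(**kwargs)


attn = build_replacement("attention", "flash", dim=256, dropout=0.1)
print(type(attn).__name__, attn.dim)  # FlashAttention 256
```

Listing the known names in the error makes a misspelled `name` easy to diagnose.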
## swap_all

Replace all modules of a given type throughout the model.

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | Model to modify |
| `module_type` | `str` | required | Category to replace |
| `name` | `str` | required | Registered replacement name |
| `**kwargs` | `dict` | `{}` | Extra args for the new modules |
## swap_by_type

Replace all instances of a specific class (regardless of registry category).

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | Model to modify |
| `old_type` | `str \| type` | required | Class name or class to replace |
| `new_type` | `str` | required | Registered replacement name |
| `**kwargs` | `dict` | `{}` | Extra args for the new modules |
## discover

List all swappable modules found in the model, grouped by type.

```python
swappable = ModuleSwapper.discover(model)
# {
#     "attention": ["trunk.layers.0.attn", "trunk.layers.1.attn", ...],
#     "block": ["trunk.layers.0", "trunk.layers.1", ...],
#     "embedder": ["embedder"],
#     "structure_module": ["structure_module"],
# }
```

| Parameter | Type | Description |
|---|---|---|
| `model` | `nn.Module` | Model to inspect |

Returns: `dict[str, list[str]]` mapping module types to lists of dotted paths.
## Example: Architecture Ablation

```python
from molfun import MolfunStructureModel
from molfun.modules.swapper import ModuleSwapper

results = {}
for attn_type in ["standard", "flash", "linear", "gated"]:
    # Reload so every variant starts from the same pretrained model
    model = MolfunStructureModel.from_pretrained("openfold_v2")
    ModuleSwapper.swap_all(model, "attention", attn_type)
    output = model.predict("MKFLILLFNILCLFPVLAADNH...")
    # Mean pLDDT (per-residue confidence) as a single score per variant
    results[attn_type] = output.plddt.mean().item()

print(results)
```
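The Class Reference documents that `swap` returns the replaced module. That makes a single-module variant reversible: swap the old module back afterwards, and each variant starts from the same baseline. The snippet below is a minimal context-manager sketch of that pattern.

`_swap` and `swapped` are hypothetical helpers, and `SimpleNamespace` stands in for `nn.Module` so the example runs without PyTorch or molfun.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any, Iterator


def _swap(root: Any, path: str, new: Any) -> Any:
    """Replace the attribute at a dotted path and return the old value."""
    *parents, leaf = path.split(".")
    for part in parents:
        root = getattr(root, part)
    old = getattr(root, leaf)
    setattr(root, leaf, new)
    return old


@contextmanager
def swapped(root: Any, path: str, new: Any) -> Iterator[Any]:
    """Temporarily install `new` at `path`; restore the original on exit."""
    old = _swap(root, path, new)
    try:
        yield root
    finally:
        _swap(root, path, old)  # runs even if the body raises


model = SimpleNamespace(trunk=SimpleNamespace(attn="standard"))
seen = {}
for variant in ["flash", "linear"]:
    with swapped(model, "trunk.attn", variant) as m:
        seen[variant] = m.trunk.attn  # evaluate the variant here
print(seen, model.trunk.attn)  # {'flash': 'flash', 'linear': 'linear'} standard
```

The `finally` clause matters in an ablation loop: a variant that crashes mid-evaluation still leaves the original module in place for the next one.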