Distributed Training¶
Molfun supports distributed training via PyTorch's DistributedDataParallel
(DDP) and Fully Sharded Data Parallel (FSDP) through the
molfun.training.distributed module.
Quick Start¶
```python
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# DDP training (data parallelism)
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="ddp",
    epochs=10,
)

# FSDP training (model + data parallelism)
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="fsdp",
    epochs=10,
)
```
Module Reference¶
distributed ¶
Distributed training strategies.
Provides BaseDistributedStrategy (ABC) and concrete implementations
for DDP and FSDP. These strategies wrap the model and loaders for
multi-GPU training, and are passed to FinetuneStrategy.fit() via the
distributed parameter.
Design¶
Follows the Strategy pattern: the distributed backend is decoupled from the fine-tuning strategy (LoRA, Partial, Full). Any combination works::

    strategy = LoRAFinetune(rank=8, lr_lora=2e-4)
    dist = DDPStrategy(backend="nccl")
    strategy.fit(model, train_loader, val_loader, epochs=20, distributed=dist)
The launch() helper handles torch.distributed initialisation and
mp.spawn so users don't need to write boilerplate.
BaseDistributedStrategy ¶
Bases: ABC
Interface for distributed training backends.
is_main_process (abstractmethod, property) ¶
True on rank 0 — controls logging, checkpointing, etc.
setup (abstractmethod) ¶
Initialise the process group for this rank.
wrap_model (abstractmethod) ¶
Wrap the model for distributed training. Returns the wrapped module.
wrap_loader (abstractmethod) ¶
Replace the loader's sampler with a DistributedSampler.
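The shape of this interface can be sketched with a plain `abc.ABC`. The class below is an illustrative reduction, not the library's implementation; the `setup` signature and the rank-tracking attribute are assumptions.

```python
from abc import ABC, abstractmethod


class BaseDistributedStrategy(ABC):
    """Toy reduction of the interface: rank 0 is the main process."""

    def __init__(self) -> None:
        self.rank: int | None = None  # set by setup(); assumed attribute

    @property
    @abstractmethod
    def is_main_process(self) -> bool:
        """True on rank 0 -- controls logging, checkpointing, etc."""

    @abstractmethod
    def setup(self, rank: int, world_size: int) -> None:
        """Initialise the process group for this rank."""


class ToyStrategy(BaseDistributedStrategy):
    """Minimal concrete strategy for illustration only."""

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0

    def setup(self, rank: int, world_size: int) -> None:
        self.rank = rank  # a real backend would call dist.init_process_group here


s = ToyStrategy()
s.setup(rank=0, world_size=4)
print(s.is_main_process)  # True
```

Because `is_main_process` is an abstract property, forgetting to implement it in a concrete strategy raises `TypeError` at instantiation time rather than failing mid-training.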
DDPStrategy ¶
Bases: BaseDistributedStrategy
Distributed Data Parallel — replicate model on each GPU, synchronise gradients via all-reduce after each backward pass.
Best for: models that fit in a single GPU's memory (most protein ML models: OpenFold ~93M params, ESMFold ~700M on A100 80 GB).
Usage::

    dist = DDPStrategy(backend="nccl")
    strategy.fit(model, train_loader, val_loader, distributed=dist)

Or with the launcher::

    from molfun.training.distributed import launch

    def train_fn(rank, world_size, dist):
        model = ...
        strategy = LoRAFinetune(...)
        strategy.fit(model, train_loader, val_loader, distributed=dist)

    launch(train_fn, DDPStrategy(backend="nccl"), world_size=4)
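The gradient synchronisation DDP performs can be illustrated without GPUs: each replica computes gradients on its own data shard, then an all-reduce averages them so every replica applies the identical update. This is a pure-Python sketch of the arithmetic, not the NCCL implementation.

```python
def all_reduce_mean(grads_per_rank):
    """Average gradients elementwise across ranks, as DDP's all-reduce does."""
    world_size = len(grads_per_rank)
    n = len(grads_per_rank[0])
    return [sum(g[i] for g in grads_per_rank) / world_size for i in range(n)]


# Two replicas, each holding gradients computed on its own data shard.
rank0_grads = [1.0, 2.0]
rank1_grads = [3.0, 6.0]
synced = all_reduce_mean([rank0_grads, rank1_grads])
print(synced)  # [2.0, 4.0] -- every rank now steps with the same gradient
```

Averaging (rather than summing) keeps the effective learning rate independent of the number of replicas, which is why DDP divides by the world size.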
FSDPStrategy ¶
Bases: BaseDistributedStrategy
Fully Sharded Data Parallel — shard parameters, gradients, and optimizer state across GPUs.
Best for: models too large for a single GPU, or when you need to maximise batch size per GPU.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| backend | str | Process group backend ("nccl", "gloo") | 'nccl' |
| sharding_strategy | str | How parameters, gradients, and optimizer state are sharded | 'full' |
| mixed_precision | str \| None | Enable bf16/fp16 compute | None |
| cpu_offload | bool | Offload parameters to CPU between forward/backward | False |
| activation_checkpointing | bool | Apply gradient checkpointing to wrapped modules automatically | False |
| auto_wrap_min_params | int | Minimum parameter count for FSDP auto-wrapping of submodules | 100000 |
launch ¶
Launch a distributed training function across world_size processes.
Each process calls fn(rank, world_size, distributed).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable | Training function with signature fn(rank, world_size, distributed) | required |
| distributed | BaseDistributedStrategy | The distributed strategy to use | required |
| world_size | int \| None | Number of GPUs | None |
Usage::

    def train(rank, world_size, dist):
        dist.setup(rank, world_size)
        model = ...
        strategy = LoRAFinetune(...)
        strategy.fit(model, train_loader, distributed=dist)
        dist.cleanup()

    launch(train, DDPStrategy(), world_size=4)
DDP -- DistributedDataParallel¶
Standard data-parallel training. Each GPU holds a full model replica and processes a shard of the data.
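The per-GPU data shard is what a `DistributedSampler` provides: each rank sees a disjoint, rank-dependent slice of the dataset indices. A minimal sketch of that striding logic (ignoring shuffling and the padding real samplers add to equalise shard sizes):

```python
def shard_indices(dataset_len, rank, world_size):
    """Rank r takes indices r, r + world_size, r + 2*world_size, ...
    -- the striding a DistributedSampler applies, minus shuffle/padding."""
    return list(range(rank, dataset_len, world_size))


# A 10-sample dataset split across 4 ranks: disjoint shards covering all indices.
shards = [shard_indices(10, r, world_size=4) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Together the shards partition the dataset, so one epoch across all ranks still visits every sample exactly once.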
Launch¶
```bash
# Using torchrun
torchrun --nproc_per_node=4 train.py

# Using the Molfun CLI
molfun run train.py --gpus 4
```
Programmatic Usage¶
```python
import os

from molfun.training.distributed import setup_ddp, cleanup_ddp

# Initialize process group
setup_ddp()

# torchrun sets LOCAL_RANK for each process it spawns
local_rank = int(os.environ["LOCAL_RANK"])

model = MolfunStructureModel.from_pretrained("openfold_v2", device=f"cuda:{local_rank}")
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="ddp",
    epochs=10,
)

cleanup_ddp()
```
Configuration¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| backend | str | "nccl" | Communication backend ("nccl", "gloo") |
| find_unused_parameters | bool | False | Enable if some parameters are not used in every forward pass |
| gradient_as_bucket_view | bool | True | Memory optimization for gradient communication |
FSDP -- Fully Sharded Data Parallel¶
Shards model parameters, gradients, and optimizer state across GPUs. Enables training models that do not fit on a single GPU.
Usage¶
```python
from molfun.training.distributed import setup_fsdp

# Initialize process group
setup_fsdp()

model = MolfunStructureModel.from_pretrained("openfold_v2")
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="fsdp",
    fsdp_config={
        "sharding_strategy": "FULL_SHARD",
        "cpu_offload": False,
        "mixed_precision": True,
    },
    epochs=10,
)
```
Configuration¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| sharding_strategy | str | "FULL_SHARD" | One of "FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD" |
| cpu_offload | bool | False | Offload parameters and gradients to CPU |
| mixed_precision | bool | True | Use mixed precision (fp16/bf16) for communication |
| auto_wrap_policy | str \| None | None | FSDP wrapping policy for sub-modules |
| activation_checkpointing | bool | False | Enable gradient checkpointing to save memory |
Multi-Node Training¶
```bash
# Node 0 (master)
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=0 \
    --master_addr=10.0.0.1 \
    --master_port=29500 \
    train.py

# Node 1
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=1 \
    --master_addr=10.0.0.1 \
    --master_port=29500 \
    train.py
```
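Under torchrun, each process derives its global rank from its node rank and local rank; the two-node, four-GPU launch above therefore yields a world size of 8 and global ranks 0 through 7:

```python
def global_rank(node_rank, local_rank, nproc_per_node):
    # torchrun numbers processes node-major: node 0 holds ranks 0..3,
    # node 1 holds ranks 4..7 in the launch above.
    return node_rank * nproc_per_node + local_rank


nnodes, nproc = 2, 4
world_size = nnodes * nproc
ranks = [global_rank(n, l, nproc) for n in range(nnodes) for l in range(nproc)]
print(world_size, ranks)  # 8 [0, 1, 2, 3, 4, 5, 6, 7]
```

Both nodes must agree on `--master_addr`/`--master_port` (the rendezvous point) and differ only in `--node_rank`.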
Recommendations¶
| Scenario | Recommended |
|---|---|
| Model fits on 1 GPU, want faster training | DDP |
| Model does not fit on 1 GPU | FSDP (FULL_SHARD) |
| Moderate memory pressure | FSDP (SHARD_GRAD_OP) |
| Very large models + limited GPUs | FSDP + cpu_offload + activation_checkpointing |
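The table reads naturally as a small decision function. The helper below is purely illustrative; its name, arguments, and the `memory_pressure` levels are not part of the Molfun API.

```python
def recommend(model_fits_on_one_gpu, memory_pressure="low"):
    """Map the scenarios in the table above to a strategy description.

    memory_pressure: "low", "moderate", or "extreme" (hypothetical levels).
    """
    if model_fits_on_one_gpu:
        if memory_pressure == "moderate":
            return "fsdp (SHARD_GRAD_OP)"
        return "ddp"
    if memory_pressure == "extreme":
        return "fsdp (FULL_SHARD) + cpu_offload + activation_checkpointing"
    return "fsdp (FULL_SHARD)"


print(recommend(True))               # 'ddp'
print(recommend(False))              # 'fsdp (FULL_SHARD)'
print(recommend(True, "moderate"))   # 'fsdp (SHARD_GRAD_OP)'
```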