
Distributed Training

Molfun supports distributed training through the molfun.training.distributed module, which wraps PyTorch's DistributedDataParallel (DDP) and Fully Sharded Data Parallel (FSDP).

Quick Start

from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")

# DDP training (data parallelism)
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="ddp",
    epochs=10,
)

# FSDP training (model + data parallelism)
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="fsdp",
    epochs=10,
)

Module Reference

distributed

Distributed training strategies.

Provides BaseDistributedStrategy (ABC) and concrete implementations for DDP and FSDP. These strategies wrap the model and loaders for multi-GPU training, and are passed to FinetuneStrategy.fit() via the distributed parameter.

Design

Follows the Strategy pattern: the distributed backend is decoupled from the fine-tuning strategy (LoRA, Partial, Full). Any combination works::

strategy = LoRAFinetune(rank=8, lr_lora=2e-4)
dist = DDPStrategy(backend="nccl")
strategy.fit(model, train_loader, val_loader, epochs=20, distributed=dist)

The launch() helper handles torch.distributed initialisation and mp.spawn so users don't need to write boilerplate.

BaseDistributedStrategy

Bases: ABC

Interface for distributed training backends.

is_main_process abstractmethod property
is_main_process: bool

True on rank 0 — controls logging, checkpointing, etc.
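The rank-0 guard is the standard way to keep side effects (log lines, checkpoint files) from being duplicated once per GPU. A minimal sketch of the pattern, using a hypothetical `SingleProcess` stand-in for a concrete BaseDistributedStrategy:

```python
# Illustration of the rank-0 guard: side effects run once per job, not once
# per GPU. `SingleProcess` is a hypothetical stand-in, not a Molfun class.
class SingleProcess:
    def __init__(self, rank: int):
        self.rank = rank

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0

def maybe_checkpoint(dist, path: str) -> bool:
    """Pretend to save a checkpoint; return True only on the main process."""
    if not dist.is_main_process:
        return False
    # ... torch.save(model.state_dict(), path) would go here ...
    return True

saved = [maybe_checkpoint(SingleProcess(rank), "ckpt.pt") for rank in range(4)]
# Only rank 0 writes: [True, False, False, False]
```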

local_rank abstractmethod property
local_rank: int

Local GPU rank for this process.

setup abstractmethod
setup(rank: int, world_size: int) -> None

Initialise process group for this rank.

wrap_model abstractmethod
wrap_model(model: Module, device: device) -> nn.Module

Wrap model for distributed training. Returns wrapped module.

wrap_loader abstractmethod
wrap_loader(loader: DataLoader, rank: int, world_size: int) -> DataLoader

Replace sampler with a DistributedSampler.
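What such an implementation might look like, sketched with the stock PyTorch DistributedSampler (the actual Molfun implementation may differ): the DataLoader is rebuilt so each rank iterates over a disjoint shard of the dataset.

```python
from torch.utils.data import DataLoader, DistributedSampler

def wrap_loader(loader: DataLoader, rank: int, world_size: int) -> DataLoader:
    """Sketch of a wrap_loader implementation: rebuild the DataLoader with a
    DistributedSampler so each rank sees a disjoint shard of the dataset."""
    sampler = DistributedSampler(
        loader.dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        sampler=sampler,                # replaces the default shuffling sampler
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        drop_last=loader.drop_last,
    )

# 10 samples over 4 ranks: each rank gets ceil(10/4) = 3 (padded by repetition)
base = DataLoader(list(range(10)), batch_size=2)
shard = wrap_loader(base, rank=0, world_size=4)
```

Note that with a DistributedSampler you must call `sampler.set_epoch(epoch)` at the start of each epoch, or every epoch reuses the same shuffle order.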

cleanup abstractmethod
cleanup() -> None

Destroy process group.

barrier
barrier() -> None

Synchronise all processes.

DDPStrategy

Bases: BaseDistributedStrategy

Distributed Data Parallel — replicate model on each GPU, synchronise gradients via all-reduce after each backward pass.

Best for: models that fit in a single GPU's memory (most protein ML models: OpenFold ~93M params, ESMFold ~700M on A100 80 GB).

Usage::

dist = DDPStrategy(backend="nccl")
strategy.fit(model, train_loader, val_loader, distributed=dist)

Or with the launcher::

from molfun.training.distributed import launch

def train_fn(rank, world_size, dist):
    model = ...
    strategy = LoRAFinetune(...)
    strategy.fit(model, train_loader, val_loader, distributed=dist)

launch(train_fn, DDPStrategy(backend="nccl"), world_size=4)

FSDPStrategy

Bases: BaseDistributedStrategy

Fully Sharded Data Parallel — shard parameters, gradients, and optimizer state across GPUs.

Best for: models too large for a single GPU, or when you need to maximise batch size per GPU.

Parameters:

backend (str, default 'nccl'): Process group backend (nccl for GPU).
sharding_strategy (str, default 'full'): "full" (ZeRO-3), "shard_grad_op" (ZeRO-2), "no_shard" (DDP-equivalent).
mixed_precision (str | None, default None): Enable bf16/fp16 compute. "bf16" or "fp16".
cpu_offload (bool, default False): Offload parameters to CPU between forward/backward.
activation_checkpointing (bool, default False): Apply gradient checkpointing to wrapped modules automatically.
auto_wrap_min_params (int, default 100000): Minimum parameter count for FSDP auto-wrapping of submodules. 0 wraps everything in one flat group.
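The effect of auto_wrap_min_params can be pictured as a size threshold that decides which submodules become their own shard unit. A toy illustration of that decision (the module names and parameter counts are made up, and this is not FSDP's internal code):

```python
# Toy illustration of size-based auto-wrapping: submodules whose parameter
# count meets `min_params` become their own FSDP unit; the rest stay in the
# root flat group. Names and counts are invented for illustration.
def split_into_units(param_counts: dict[str, int], min_params: int):
    wrapped = {name for name, n in param_counts.items() if n >= min_params}
    root = set(param_counts) - wrapped
    return wrapped, root

counts = {
    "evoformer.block_0": 2_100_000,
    "evoformer.block_1": 2_100_000,
    "structure_head": 450_000,
    "input_embedder": 60_000,
}
wrapped, root = split_into_units(counts, min_params=100_000)
# wrapped: both evoformer blocks and structure_head; root: input_embedder
```

Smaller thresholds mean more, finer-grained units: lower peak memory (parameters are gathered one unit at a time) at the cost of more communication calls.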

launch

launch(fn, distributed: BaseDistributedStrategy, world_size: int | None = None)

Launch a distributed training function across world_size processes.

Each process calls fn(rank, world_size, distributed).

Parameters:

fn (required): Training function with signature (rank, world_size, dist) -> None.
distributed (BaseDistributedStrategy, required): The distributed strategy to use.
world_size (int | None, default None): Number of GPUs. Defaults to torch.cuda.device_count().

Usage::

def train(rank, world_size, dist):
    dist.setup(rank, world_size)
    model = ...
    strategy = LoRAFinetune(...)
    strategy.fit(model, train_loader, distributed=dist)
    dist.cleanup()

launch(train, DDPStrategy(), world_size=4)

DDP -- DistributedDataParallel

Standard data-parallel training. Each GPU holds a full model replica and processes a shard of the data.
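Because each replica consumes a disjoint shard, one optimizer step sees per-GPU batch size times world size samples. A quick sketch of that arithmetic; the linear learning-rate scaling shown is a common heuristic, not a Molfun default:

```python
def effective_batch(per_gpu_batch: int, world_size: int, grad_accum: int = 1) -> int:
    """Global batch size consumed by one optimizer step under data parallelism."""
    return per_gpu_batch * world_size * grad_accum

def linear_lr(base_lr: float, base_batch: int, global_batch: int) -> float:
    """Linear LR scaling rule (a common heuristic, not a Molfun requirement)."""
    return base_lr * global_batch / base_batch

gb = effective_batch(per_gpu_batch=8, world_size=4)          # 32
lr = linear_lr(base_lr=1e-4, base_batch=8, global_batch=gb)  # 4e-4
```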

Launch

# Using torchrun
torchrun --nproc_per_node=4 train.py

# Using the Molfun CLI
molfun run train.py --gpus 4

Programmatic Usage

import os

from molfun.training.distributed import setup_ddp, cleanup_ddp

# Initialize process group
setup_ddp()

# torchrun sets LOCAL_RANK for every worker process
local_rank = int(os.environ["LOCAL_RANK"])
model = MolfunStructureModel.from_pretrained("openfold_v2", device=f"cuda:{local_rank}")
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="ddp",
    epochs=10,
)

cleanup_ddp()

Configuration

backend (str, default "nccl"): Communication backend ("nccl" or "gloo").
find_unused_parameters (bool, default False): Enable if some parameters are not used every forward pass.
gradient_as_bucket_view (bool, default True): Memory optimization for gradient communication.
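These options map onto PyTorch's own DistributedDataParallel constructor. A minimal sketch of those underlying calls, run here as a single-process "gloo" group on CPU so it works anywhere; real training would use "nccl" with one process per GPU:

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group via a file rendezvous (no network needed).
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)

model = torch.nn.Linear(16, 4)
ddp_model = DDP(
    model,
    find_unused_parameters=False,   # leave off unless some params skip forward
    gradient_as_bucket_view=True,   # gradients share storage with comm buckets
)

out = ddp_model(torch.randn(2, 16))
dist.destroy_process_group()
```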

FSDP -- Fully Sharded Data Parallel

Shards model parameters, gradients, and optimizer state across GPUs. Enables training models that do not fit on a single GPU.

Usage

from molfun.training.distributed import setup_fsdp

model = MolfunStructureModel.from_pretrained("openfold_v2")
model.fit(
    train_dataset=ds,
    strategy="full",
    distributed="fsdp",
    fsdp_config={
        "sharding_strategy": "FULL_SHARD",
        "cpu_offload": False,
        "mixed_precision": True,
    },
    epochs=10,
)

Configuration

sharding_strategy (str, default "FULL_SHARD"): One of "FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD".
cpu_offload (bool, default False): Offload parameters and gradients to CPU.
mixed_precision (bool, default True): Use mixed precision (fp16/bf16) for communication.
auto_wrap_policy (str | None, default None): FSDP wrapping policy for sub-modules.
activation_checkpointing (bool, default False): Enable gradient checkpointing to save memory.
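Under the hood, FSDP mixed precision is configured through PyTorch's MixedPrecision policy object. How Molfun's boolean flag maps onto it is an assumption here; this sketch shows a typical bf16 setup:

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# A typical bf16 policy: compute, gradient reduction, and buffers all in bf16.
# fp32 master weights are kept by the optimizer regardless.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,    # forward/backward compute dtype
    reduce_dtype=torch.bfloat16,   # dtype for gradient all-reduce/reduce-scatter
    buffer_dtype=torch.bfloat16,   # buffers, e.g. normalization statistics
)
```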

Multi-Node Training

# Node 0 (master)
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=0 \
    --master_addr=10.0.0.1 \
    --master_port=29500 \
    train.py

# Node 1
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=1 \
    --master_addr=10.0.0.1 \
    --master_port=29500 \
    train.py
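Inside train.py, each worker can discover its place in the job from the environment variables torchrun sets. A small sketch of reading them (the defaults chosen for single-process runs are this example's, not torchrun's):

```python
import os

def torchrun_context() -> dict:
    """Read the environment variables torchrun sets in every worker process."""
    return {
        "rank": int(os.environ.get("RANK", 0)),              # global rank
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),  # total processes
        "master_addr": os.environ.get("MASTER_ADDR", "127.0.0.1"),
    }

# e.g. the second GPU of node 1 in the 2-node x 4-GPU job above:
os.environ.update(
    {"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8", "MASTER_ADDR": "10.0.0.1"}
)
ctx = torchrun_context()
```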

Recommendations

Scenario Recommended
Model fits on 1 GPU, want faster training DDP
Model does not fit on 1 GPU FSDP (FULL_SHARD)
Moderate memory pressure FSDP (SHARD_GRAD_OP)
Very large models + limited GPUs FSDP + cpu_offload + activation_checkpointing
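The table above follows from a back-of-envelope memory estimate. Counting only fp32 parameters, gradients, and Adam moments (16 bytes per parameter, ignoring activations and buffers), per-GPU training state under each strategy works out roughly as:

```python
def per_gpu_gib(n_params: int, world_size: int, strategy: str) -> float:
    """Rough per-GPU memory (GiB) for fp32 params + grads + Adam moments,
    16 bytes/parameter total. Ignores activations, buffers, and overhead."""
    p, g, o = 4, 4, 8  # bytes per parameter: params, grads, optimizer state
    if strategy == "ddp":              # everything replicated on every GPU
        b = p + g + o
    elif strategy == "shard_grad_op":  # ZeRO-2: grads + optimizer state sharded
        b = p + (g + o) / world_size
    elif strategy == "full":           # ZeRO-3: params, grads, state all sharded
        b = (p + g + o) / world_size
    else:
        raise ValueError(strategy)
    return n_params * b / 2**30

# A 700M-parameter model (roughly ESMFold-sized) across 4 GPUs:
n = 700_000_000
ddp = per_gpu_gib(n, 4, "ddp")    # ~10.4 GiB of training state per GPU
full = per_gpu_gib(n, 4, "full")  # ~2.6 GiB, leaving room for activations
```

Activations usually dominate on top of this, which is where activation_checkpointing buys further headroom.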