Datasets

PyTorch Dataset implementations for structure prediction, affinity prediction, and streaming workloads.

Quick Start

from molfun.data.datasets import StructureDataset, AffinityDataset

# Structure dataset from PDB files
ds = StructureDataset(
    data_dir="./pdb_files",
    max_length=512,
)

# Affinity dataset
ds = AffinityDataset(
    csv_path="./affinity_data.csv",
    structure_dir="./structures",
    ligand_dir="./ligands",
)

# Use with DataLoader
from torch.utils.data import DataLoader
loader = DataLoader(ds, batch_size=2, num_workers=4)

StructureDataset

Bases: Dataset

Dataset of protein structures for fine-tuning.

Each item returns a dict of features + optional label, ready for the model adapter's forward() method.

Usage

From PDB files (on-the-fly parsing)

ds = StructureDataset(pdb_paths=["1abc.cif", "2xyz.cif"])

From pre-computed feature pickles (OpenFold-style)

ds = StructureDataset(
    pdb_paths=["1abc.cif"],
    features_dir="features/",  # contains 1abc.pkl
)

With labels

ds = StructureDataset(
    pdb_paths=["1abc.cif"],
    labels={"1abc": 6.5},
)

sequences property

sequences: list[str]

Extract sequences from all structures (lazy, caches on first call).

Dataset for protein structure prediction training. Loads PDB/mmCIF files and produces feature dictionaries compatible with the model forward pass.

from molfun.data.datasets import StructureDataset

ds = StructureDataset(
    data_dir="./pdb_files",
    max_length=512,
    crop_strategy="contiguous",
    format="mmcif",
)

sample = ds[0]
# {
#     "aatype": Tensor (L,),
#     "residue_index": Tensor (L,),
#     "all_atom_positions": Tensor (L, 37, 3),
#     "all_atom_mask": Tensor (L, 37),
#     ...
# }
| Parameter | Type | Default | Description |
|---|---|---|---|
| data_dir | str \| Path | required | Directory containing structure files |
| max_length | int | 512 | Maximum sequence length (longer chains are cropped) |
| crop_strategy | str | "contiguous" | Cropping strategy: "contiguous", "random", or "spatial" |
| format | str | "mmcif" | Input format: "pdb" or "mmcif" |
| msa_dir | str \| None | None | Directory with precomputed MSA files |
| transform | callable \| None | None | Optional transform applied to each sample |
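The crop_strategy option decides which residues survive when a chain is longer than max_length. As a rough illustration (a self-contained sketch, not molfun's actual implementation — the helper name `crop_indices` is invented here), contiguous cropping keeps a random window, while random cropping keeps a random subset of positions:

```python
import random

def crop_indices(length, max_length, strategy="contiguous", rng=None):
    """Pick residue indices to keep when a chain exceeds max_length.

    Hypothetical sketch of "contiguous" vs "random" cropping;
    the real StructureDataset logic may differ in detail.
    """
    rng = rng or random.Random(0)
    if length <= max_length:
        # Short chains are kept whole.
        return list(range(length))
    if strategy == "contiguous":
        # Keep one random contiguous window of max_length residues.
        start = rng.randrange(length - max_length + 1)
        return list(range(start, start + max_length))
    if strategy == "random":
        # Keep a random subset of positions, in sequence order.
        return sorted(rng.sample(range(length), max_length))
    raise ValueError(f"unknown strategy: {strategy}")
```

Spatial cropping (selecting residues near a 3-D center) needs coordinates and is omitted from this sketch.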

AffinityDataset

Bases: Dataset

Dataset for binding affinity prediction.

Combines protein structures with scalar affinity labels. Each item returns (features_dict, label_tensor).

Usage

From AffinityRecords + PDB directory

ds = AffinityDataset.from_records(
    records=records,
    pdb_dir="~/.molfun/pdb_cache",
)

From CSV + PDB directory

ds = AffinityDataset.from_csv(
    csv_path="data/pdbbind.csv",
    pdb_dir="pdbs/",
)
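Conceptually, from_csv first reads the index CSV into (pdb_id, label) pairs before any structure file is touched. A minimal self-contained sketch of that bookkeeping (the function name `read_affinity_index` is hypothetical, not part of molfun):

```python
import csv
import io

def read_affinity_index(csv_text, pdb_col="pdb_id", affinity_col="affinity"):
    """Parse an affinity index CSV into (pdb_id, label) pairs.

    Sketch of the first step a from_csv-style constructor performs;
    structure files are only opened later, when items are requested.
    """
    rows = csv.DictReader(io.StringIO(csv_text))
    return [(row[pdb_col], float(row[affinity_col])) for row in rows]
```

Keeping this index lightweight is what lets the dataset defer all expensive parsing to `__getitem__`.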

from_records classmethod

from_records(records: list[AffinityRecord], pdb_dir: str | Path, fmt: str = 'cif', features_dir: str | Path | None = None, max_seq_len: int = 512, transform: Callable | None = None) -> AffinityDataset

Build dataset from AffinityRecords + a directory of PDB/mmCIF files.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| records | list[AffinityRecord] | List of AffinityRecord from AffinityFetcher. | required |
| pdb_dir | str \| Path | Directory containing structure files named {pdb_id}.{fmt}. | required |
| fmt | str | File extension ("cif" or "pdb"). | 'cif' |
| features_dir | str \| Path \| None | Optional pre-computed features directory. | None |
| max_seq_len | int | Crop sequences longer than this. | 512 |
| transform | Callable \| None | Optional transform on feature dicts. | None |

from_csv classmethod

from_csv(csv_path: str | Path, pdb_dir: str | Path, fmt: str = 'cif', pdb_col: str = 'pdb_id', affinity_col: str = 'affinity', features_dir: str | Path | None = None, max_seq_len: int = 512, transform: Callable | None = None) -> AffinityDataset

Build dataset directly from a CSV file + PDB directory.

collate_fn staticmethod

collate_fn(batch)

Use this as DataLoader collate_fn for variable-length structures.
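Variable-length structures cannot be stacked directly, so a collate function must pad each sample to the batch maximum and record a mask of real vs padded positions. The sketch below shows the idea with plain Python lists (the real collate_fn operates on tensors; `pad_collate` is an invented name, not molfun API):

```python
def pad_collate(batch, pad_value=0.0):
    """Pad variable-length 1-D feature sequences to a common length.

    Returns (padded, mask), where mask marks real positions with 1
    and padding with 0 — the core of any variable-length collate_fn.
    """
    max_len = max(len(x) for x in batch)
    padded = [list(x) + [pad_value] * (max_len - len(x)) for x in batch]
    mask = [[1] * len(x) + [0] * (max_len - len(x)) for x in batch]
    return padded, mask
```

The mask is what lets downstream losses and attention layers ignore the padded tail of each structure.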

Dataset for protein-ligand binding affinity prediction.

from molfun.data.datasets import AffinityDataset

ds = AffinityDataset(
    csv_path="./affinity_labels.csv",
    structure_dir="./structures",
    ligand_dir="./ligands",
    label_column="pKd",
)

sample = ds[0]
# {
#     "aatype": Tensor,
#     "residue_index": Tensor,
#     "ligand_features": Tensor,
#     "affinity_label": Tensor,
#     ...
# }
| Parameter | Type | Default | Description |
|---|---|---|---|
| csv_path | str \| Path | required | CSV with PDB IDs and affinity labels |
| structure_dir | str \| Path | required | Directory with protein structure files |
| ligand_dir | str \| Path | required | Directory with ligand files (SDF) |
| label_column | str | "affinity" | Column name for affinity labels |
| max_length | int | 512 | Maximum protein sequence length |
| transform | callable \| None | None | Optional transform |

StreamingStructureDataset

Bases: IterableDataset

IterableDataset that streams protein structures from any fsspec filesystem.

Reads an index CSV (pdb_id, affinity, ...) and lazily loads structure files (.pkl or .cif) on demand, so the full dataset never needs to be downloaded.

Multi-worker safe: each worker processes a disjoint shard of the index.
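The disjoint sharding follows the standard IterableDataset pattern: each worker keeps every num_workers-th row of the index, starting at its own worker id, so shards never overlap and together cover the whole index. A minimal sketch (`worker_shard` is a hypothetical helper, not molfun API):

```python
def worker_shard(index, worker_id, num_workers):
    """Return this worker's disjoint slice of the index.

    Striding by num_workers from offset worker_id guarantees that
    the shards are pairwise disjoint and their union is the index.
    """
    return index[worker_id::num_workers]
```

In a real IterableDataset, `worker_id` and `num_workers` would come from `torch.utils.data.get_worker_info()` inside `__iter__`.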

__init__

__init__(index_path: str, structures_prefix: str, pdb_col: str = 'pdb_id', label_col: str = 'affinity', fmt: str = 'pkl', shuffle_buffer: int = 0, max_seq_len: int = 512, transform: Callable | None = None, storage_options: dict | None = None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| index_path | str | Path to CSV index file (local or remote). | required |
| structures_prefix | str | Directory/prefix containing structure files. | required |
| pdb_col | str | Column name for the PDB ID in the CSV. | 'pdb_id' |
| label_col | str | Column name for the label (affinity). | 'affinity' |
| fmt | str | File format for structures: "pkl" (pre-computed features) or "cif"/"pdb" (raw structures, requires BioPython). | 'pkl' |
| shuffle_buffer | int | If > 0, maintains a buffer of this size and yields samples randomly from it (reservoir sampling). | 0 |
| max_seq_len | int | Crop sequences longer than this. | 512 |
| transform | Callable \| None | Optional transform on feature dicts. | None |
| storage_options | dict \| None | fsspec options (e.g. endpoint_url for MinIO). | None |
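The shuffle_buffer mechanism can be pictured as follows: fill a fixed-size buffer, then for every incoming sample emit a randomly chosen buffered one and put the new sample in its place; when the stream ends, flush the remainder in random order. A self-contained sketch (`buffered_shuffle` is an invented name; the actual implementation may differ):

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=0):
    """Yield items in approximately random order with bounded memory.

    Sketch of buffer-based streaming shuffle: only buffer_size items
    are held at once, trading shuffle quality for memory.
    """
    rng = random.Random(seed)
    buf = []
    for item in iterable:
        if len(buf) < buffer_size:
            buf.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buf[i]          # emit a random buffered item
        buf[i] = item         # replace it with the new one
    rng.shuffle(buf)          # flush the tail in random order
    yield from buf
```

Larger buffers approximate a full shuffle more closely; buffer_size 0 (the default) would simply pass items through in index order.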

An IterableDataset for large-scale training that streams data from disk without loading everything into memory.

from molfun.data.datasets import StreamingStructureDataset

ds = StreamingStructureDataset(
    index_path="./index.csv",
    structures_prefix="./large_pdb_dataset",
    max_seq_len=512,
    shuffle_buffer=1000,
)

loader = DataLoader(ds, batch_size=2, num_workers=4)
for batch in loader:
    ...
| Parameter | Type | Default | Description |
|---|---|---|---|
| index_path | str | required | Path to CSV index file (local or remote) |
| structures_prefix | str | required | Directory/prefix containing structure files |
| max_seq_len | int | 512 | Maximum sequence length |
| shuffle_buffer | int | 0 | Buffer size for streaming shuffle (0 disables shuffling) |
| fmt | str | "pkl" | Structure file format: "pkl", "cif", or "pdb" |