# Datasets

PyTorch `Dataset` implementations for structure prediction, affinity prediction, and streaming workloads.
## Quick Start

```python
from molfun.data.datasets import StructureDataset, AffinityDataset

# Structure dataset from PDB files
ds = StructureDataset(
    data_dir="./pdb_files",
    max_length=512,
)

# Affinity dataset
ds = AffinityDataset(
    csv_path="./affinity_data.csv",
    structure_dir="./structures",
    ligand_dir="./ligands",
)

# Use with DataLoader
from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=2, num_workers=4)
```
## StructureDataset

### `StructureDataset`

Bases: `Dataset`

Dataset of protein structures for fine-tuning. Each item returns a dict of features plus an optional label, ready for the model adapter's `forward()` method.
**Usage**

From PDB files (on-the-fly parsing):

```python
ds = StructureDataset(pdb_paths=["1abc.cif", "2xyz.cif"])
```

From pre-computed feature pickles (OpenFold-style):

```python
ds = StructureDataset(
    pdb_paths=["1abc.cif"],
    features_dir="features/",  # contains 1abc.pkl
)
```

With labels:

```python
ds = StructureDataset(
    pdb_paths=["1abc.cif"],
    labels={"1abc": 6.5},
)
```
### `sequences` (property)

Extract sequences from all structures (lazy; caches on first call).
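Lazy caching of this kind typically looks like the following minimal sketch — a hypothetical reimplementation for illustration, not the library's actual code:

```python
class LazySequences:
    """Sketch of a lazily computed, cached ``sequences`` property."""

    def __init__(self, structures):
        self._structures = structures
        self._sequences = None  # filled on first access
        self.parse_calls = 0    # instrumentation for the example only

    @property
    def sequences(self):
        # Parse once on first access, then serve the cached result.
        if self._sequences is None:
            self.parse_calls += 1
            self._sequences = [s.upper() for s in self._structures]
        return self._sequences


ds = LazySequences(["mkvl", "acde"])
first = ds.sequences   # triggers parsing
second = ds.sequences  # served from cache; no re-parse
```

Accessing the property repeatedly (e.g. in a training loop) therefore costs a single parse.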
Dataset for protein structure prediction training. Loads PDB/mmCIF files and produces feature dictionaries compatible with the model forward pass.
```python
from molfun.data.datasets import StructureDataset

ds = StructureDataset(
    data_dir="./pdb_files",
    max_length=512,
    crop_strategy="contiguous",
    format="mmcif",
)

sample = ds[0]
# {
#     "aatype": Tensor (L,),
#     "residue_index": Tensor (L,),
#     "all_atom_positions": Tensor (L, 37, 3),
#     "all_atom_mask": Tensor (L, 37),
#     ...
# }
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `data_dir` | `str \| Path` | required | Directory containing structure files |
| `max_length` | `int` | `512` | Maximum sequence length (longer chains are cropped) |
| `crop_strategy` | `str` | `"contiguous"` | Cropping strategy: `"contiguous"`, `"random"`, or `"spatial"` |
| `format` | `str` | `"mmcif"` | Input format: `"pdb"` or `"mmcif"` |
| `msa_dir` | `str \| None` | `None` | Directory with precomputed MSA files |
| `transform` | `callable \| None` | `None` | Optional transform applied to each sample |
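A `transform` is just a callable applied to each sample dict before it is returned. As a hedged illustration (the function name and the use of the `all_atom_positions`/`all_atom_mask` keys mirror the sample layout shown above, but are not part of the library's API), here is a transform that centers atom coordinates at the origin:

```python
import torch


def center_positions(sample: dict) -> dict:
    """Example transform: shift masked atom coordinates so their centroid is at the origin."""
    pos = sample["all_atom_positions"]            # (L, 37, 3)
    mask = sample["all_atom_mask"].unsqueeze(-1)  # (L, 37, 1)
    n = mask.sum().clamp(min=1.0)                 # number of real atoms
    centroid = (pos * mask).sum(dim=(0, 1)) / n   # (3,)
    sample["all_atom_positions"] = pos - centroid
    return sample


sample = {
    "all_atom_positions": torch.randn(8, 37, 3),
    "all_atom_mask": torch.ones(8, 37),
}
out = center_positions(sample)
```

A transform like this would be passed as `transform=center_positions` at dataset construction.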
## AffinityDataset

### `AffinityDataset`

Bases: `Dataset`

Dataset for binding affinity prediction. Combines protein structures with scalar affinity labels; each item returns `(features_dict, label_tensor)`.
**Usage**

From `AffinityRecord`s + a PDB directory:

```python
ds = AffinityDataset.from_records(
    records=records,
    pdb_dir="~/.molfun/pdb_cache",
)
```

From CSV + a PDB directory:

```python
ds = AffinityDataset.from_csv(
    csv_path="data/pdbbind.csv",
    pdb_dir="pdbs/",
)
```
### `from_records` (classmethod)

```python
from_records(
    records: list[AffinityRecord],
    pdb_dir: str | Path,
    fmt: str = "cif",
    features_dir: str | Path | None = None,
    max_seq_len: int = 512,
    transform: Callable | None = None,
) -> AffinityDataset
```

Build a dataset from `AffinityRecord`s plus a directory of PDB/mmCIF files.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `records` | `list[AffinityRecord]` | List of `AffinityRecord` from `AffinityFetcher`. | required |
| `pdb_dir` | `str \| Path` | Directory containing structure files named `{pdb_id}.{fmt}`. | required |
| `fmt` | `str` | File extension (`"cif"` or `"pdb"`). | `'cif'` |
| `features_dir` | `str \| Path \| None` | Optional pre-computed features directory. | `None` |
| `max_seq_len` | `int` | Crop sequences longer than this. | `512` |
| `transform` | `Callable \| None` | Optional transform on feature dicts. | `None` |
### `from_csv` (classmethod)

```python
from_csv(
    csv_path: str | Path,
    pdb_dir: str | Path,
    fmt: str = "cif",
    pdb_col: str = "pdb_id",
    affinity_col: str = "affinity",
    features_dir: str | Path | None = None,
    max_seq_len: int = 512,
    transform: Callable | None = None,
) -> AffinityDataset
```

Build a dataset directly from a CSV file plus a PDB directory.
### `collate_fn` (staticmethod)

Use this as the `DataLoader` `collate_fn` for variable-length structures.
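A collate function for variable-length structures typically pads each tensor to the longest item in the batch. The sketch below shows the general pattern under that assumption; the library's actual `collate_fn` may differ in details such as mask handling:

```python
import torch
import torch.nn.functional as F


def pad_collate(batch):
    """Pad (features_dict, label) pairs to the longest sequence in the batch."""
    feats, labels = zip(*batch)
    max_len = max(f["aatype"].shape[0] for f in feats)
    out = {}
    for key in feats[0]:
        padded = []
        for f in feats:
            t = f[key]
            pad = max_len - t.shape[0]
            # F.pad pads trailing dims first, so this pads only the length dim.
            padded.append(F.pad(t, (0, 0) * (t.dim() - 1) + (0, pad)))
        out[key] = torch.stack(padded)
    return out, torch.stack(labels)


batch = [
    ({"aatype": torch.arange(5)}, torch.tensor(6.5)),
    ({"aatype": torch.arange(1, 4)}, torch.tensor(4.2)),
]
feats, labels = pad_collate(batch)
```

Pass it to the loader as `DataLoader(ds, batch_size=2, collate_fn=AffinityDataset.collate_fn)`.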
Dataset for protein-ligand binding affinity prediction.

```python
from molfun.data.datasets import AffinityDataset

ds = AffinityDataset(
    csv_path="./affinity_labels.csv",
    structure_dir="./structures",
    ligand_dir="./ligands",
    label_column="pKd",
)

sample = ds[0]
# {
#     "aatype": Tensor,
#     "residue_index": Tensor,
#     "ligand_features": Tensor,
#     "affinity_label": Tensor,
#     ...
# }
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `csv_path` | `str \| Path` | required | CSV with PDB IDs and affinity labels |
| `structure_dir` | `str \| Path` | required | Directory with protein structure files |
| `ligand_dir` | `str \| Path` | required | Directory with ligand files (SDF) |
| `label_column` | `str` | `"affinity"` | Column name for affinity labels |
| `max_length` | `int` | `512` | Maximum protein sequence length |
| `transform` | `callable \| None` | `None` | Optional transform |
## StreamingStructureDataset

### `StreamingStructureDataset`

Bases: `IterableDataset`

An `IterableDataset` that streams protein structures from any fsspec filesystem. It reads an index CSV (`pdb_id`, `affinity`, ...) and lazily loads structure files (`.pkl` or `.cif`) on demand, so no full download is needed.

Multi-worker safe: each worker processes a disjoint shard of the index.
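Worker sharding in an `IterableDataset` usually follows the pattern below: each worker reads a strided slice of the index so no row is processed twice. This is a sketch of the idea using torch's worker API, not the library's actual source:

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedIterable(IterableDataset):
    """Each worker iterates a disjoint, strided shard of the index."""

    def __init__(self, index):
        self.index = index

    def __iter__(self):
        info = get_worker_info()
        if info is None:                 # single-process loading
            start, step = 0, 1
        else:                            # worker i takes rows i, i + n, i + 2n, ...
            start, step = info.id, info.num_workers
        yield from self.index[start::step]


ds = ShardedIterable(list(range(10)))
# batch_size=None disables collation, so items come back one by one.
items = list(DataLoader(ds, batch_size=None, num_workers=0))
```

With `num_workers > 0`, each worker process sees a different `get_worker_info()` and thus a different shard.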
### `__init__`

```python
__init__(
    index_path: str,
    structures_prefix: str,
    pdb_col: str = "pdb_id",
    label_col: str = "affinity",
    fmt: str = "pkl",
    shuffle_buffer: int = 0,
    max_seq_len: int = 512,
    transform: Callable | None = None,
    storage_options: dict | None = None,
)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `index_path` | `str` | Path to the CSV index file (local or remote). | required |
| `structures_prefix` | `str` | Directory/prefix containing structure files. | required |
| `pdb_col` | `str` | Column name for the PDB ID in the CSV. | `'pdb_id'` |
| `label_col` | `str` | Column name for the label (affinity). | `'affinity'` |
| `fmt` | `str` | File format for structures: `"pkl"` (pre-computed features) or `"cif"`/`"pdb"` (raw structures; requires BioPython). | `'pkl'` |
| `shuffle_buffer` | `int` | If > 0, maintains a buffer of this size and yields samples randomly from it (reservoir-style sampling). | `0` |
| `max_seq_len` | `int` | Crop sequences longer than this. | `512` |
| `transform` | `Callable \| None` | Optional transform on feature dicts. | `None` |
| `storage_options` | `dict \| None` | fsspec options (e.g. `endpoint_url` for MinIO). | `None` |
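The `shuffle_buffer` behaviour corresponds to buffer-based shuffling over a stream: keep a fixed-size buffer, and on each new item emit a randomly chosen buffered one. The following is a sketch of that technique (not the library's implementation):

```python
import random


def buffered_shuffle(stream, buffer_size, seed=0):
    """Yield items from `stream` in approximately random order.

    Holds at most `buffer_size` items in memory: each incoming item
    replaces a randomly chosen buffered one, which is yielded.
    """
    rng = random.Random(seed)
    buf = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)
        else:
            i = rng.randrange(buffer_size)
            yield buf[i]
            buf[i] = item
    rng.shuffle(buf)  # drain the remainder in random order
    yield from buf


out = list(buffered_shuffle(range(100), buffer_size=16))
```

Larger buffers give better mixing at the cost of memory, which is why `shuffle_buffer=0` (no shuffling) is the conservative default.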
An `IterableDataset` for large-scale training that streams data from disk without loading everything into memory.

```python
from molfun.data.datasets import StreamingStructureDataset
from torch.utils.data import DataLoader

ds = StreamingStructureDataset(
    data_dir="./large_pdb_dataset",
    max_length=512,
    shuffle_buffer=1000,
)

loader = DataLoader(ds, batch_size=2, num_workers=4)
for batch in loader:
    ...
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `data_dir` | `str \| Path` | required | Directory containing structure files |
| `max_length` | `int` | `512` | Maximum sequence length |
| `shuffle_buffer` | `int` | `1000` | Buffer size for streaming shuffle |
| `format` | `str` | `"mmcif"` | Input format |
| `seed` | `int \| None` | `None` | Random seed for reproducible shuffling |