
# Embedders

Embedders convert raw amino acid sequences (and optional features such as MSAs) into the initial single and pair representations consumed by trunk blocks. All implementations subclass `BaseEmbedder` and are registered in `EMBEDDER_REGISTRY`.

## Quick Start

```python
from molfun.modules.embedders import EMBEDDER_REGISTRY

# List available embedders
print(EMBEDDER_REGISTRY.list())
# ["esm", "input"]

# Build an embedder
embedder = EMBEDDER_REGISTRY.build("input", c_s=256, c_z=128)
output = embedder(aatype, residue_index)

# Swap embedder in a live model
from molfun import MolfunStructureModel
model = MolfunStructureModel.from_pretrained("openfold_v2")
model.swap("embedder", "esm")
```

## EMBEDDER_REGISTRY

*module-attribute*

```python
EMBEDDER_REGISTRY = ModuleRegistry('embedder')
```
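The registry maps string names to embedder classes so that configurations can select an implementation by name. A minimal pure-Python sketch of the idea (illustrative only: the real `ModuleRegistry` may differ, and the `register` decorator and internal dict shown here are assumptions, not the library's API):

```python
class ModuleRegistry:
    """Toy name -> class registry illustrating list()/build() semantics."""

    def __init__(self, kind: str):
        self.kind = kind      # e.g. "embedder"; used in error messages
        self._classes = {}    # name -> class

    def register(self, name: str):
        """Decorator that records a class under `name`."""
        def deco(cls):
            self._classes[name] = cls
            return cls
        return deco

    def list(self):
        return sorted(self._classes)

    def build(self, name: str, **kwargs):
        """Instantiate the registered class, forwarding keyword args."""
        if name not in self._classes:
            raise KeyError(f"unknown {self.kind}: {name!r}")
        return self._classes[name](**kwargs)


EMBEDDER_REGISTRY = ModuleRegistry("embedder")

@EMBEDDER_REGISTRY.register("input")
class InputEmbedder:
    def __init__(self, c_s: int, c_z: int):
        self.c_s, self.c_z = c_s, c_z

emb = EMBEDDER_REGISTRY.build("input", c_s=256, c_z=128)
```

The payoff of this pattern is that `build("input", ...)` can come straight from a config file, with unknown names failing loudly.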

## BaseEmbedder

Bases: `ABC`, `Module`

Converts raw features → initial representations for trunk blocks.

Different paradigms:

- **AF2 InputEmbedder**: aatype + relpos → single; aatype outer product → pair; MSA features → msa
- **ESM Embedder**: frozen LM → single repr; optional pair from attention maps
- **Sequence Embedder**: learned embedding table (baseline)

### d_single *(abstract property)*

```python
d_single: int
```

Output single representation dimension.

### d_pair *(abstract property)*

```python
d_pair: int
```

Output pair representation dimension (0 if no pair track).

### forward *(abstract method)*

```python
forward(
    aatype: Tensor,
    residue_index: Tensor,
    msa: Tensor | None = None,
    msa_mask: Tensor | None = None,
    **kwargs,
) -> EmbedderOutput
```

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `aatype` | `Tensor` | Residue types `[B, L]` (int64, 0-20). | required |
| `residue_index` | `Tensor` | Position indices `[B, L]`. | required |
| `msa` | `Tensor \| None` | MSA features `[B, N, L, D_msa_feat]` (optional). | `None` |
| `msa_mask` | `Tensor \| None` | MSA mask `[B, N, L]` (optional). | `None` |
| `**kwargs` | `dict` | Additional features (e.g., templates). | |

Returns:

| Type | Description |
|------|-------------|
| `EmbedderOutput` | `EmbedderOutput` with initial single and pair representations. |


## EmbedderOutput

*dataclass*

Standardized output from any embedder.

| Field | Type | Description |
|-------|------|-------------|
| `s` | `Tensor` | Single representation `(B, L, c_s)` |
| `z` | `Tensor` | Pair representation `(B, L, L, c_z)` |

## InputEmbedder

Bases: `BaseEmbedder`

AlphaFold2-style input embedding: computes single and pair representations from amino acid types and relative positional encodings. Creates:

- single repr from aatype one-hot
- pair repr from outer product of aatype + relative position
- MSA repr from MSA features (or broadcasts the single repr if no MSA is given)

```python
embedder = EMBEDDER_REGISTRY.build(
    "input",
    c_s=256,
    c_z=128,
    max_relative_position=32,
)

output = embedder(aatype, residue_index)
s = output.s   # (B, L, 256)
z = output.z   # (B, L, L, 128)
```
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `c_s` | `int` | required | Single representation dimension |
| `c_z` | `int` | required | Pair representation dimension |
| `max_relative_position` | `int` | `32` | Maximum relative positional encoding distance |
| `num_amino_acids` | `int` | `21` | Number of amino acid types (including unknown) |
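The relative positional part of the pair representation clips the offset `j - i` to `±max_relative_position` and bins it into `2 * max + 1` classes. A minimal sketch of that binning (the clip-and-shift scheme follows AlphaFold2; the real module then embeds each bin index into `c_z` channels, which is omitted here):

```python
def relpos_index(residue_index: list[int], max_relative_position: int = 32) -> list[list[int]]:
    """Clipped relative-offset bin for every residue pair.

    Offsets j - i are clipped to [-max, max], then shifted to [0, 2*max],
    giving 2*max + 1 possible bins per pair.
    """
    m = max_relative_position
    return [
        [min(max(j - i, -m), m) + m for j in residue_index]
        for i in residue_index
    ]

bins = relpos_index([0, 1, 50], max_relative_position=32)
```

Clipping means all pairs farther apart than `max_relative_position` share one bin in each direction, which keeps the encoding length-independent.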

## ESMEmbedder

Bases: `BaseEmbedder`

Frozen ESM-2 language model used as the input embedder. Produces single representations from ESM hidden states and, optionally, pair representations from attention maps.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `esm_model` | `str` | ESM model name (e.g. `"esm2_t33_650M_UR50D"`). | `'esm2_t33_650M_UR50D'` |
| `d_single` | `int` | Output single dimension (projected from the ESM hidden dim). | `384` |
| `d_pair` | `int` | Output pair dimension (`0` to disable pair extraction). | `128` |
| `freeze_lm` | `bool` | Whether to freeze the language model weights. | `True` |
| `layer_idx` | `int` | Which ESM layer to extract representations from (`-1` = last). | `-1` |
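One common way to turn attention maps into a pair track (a sketch of the general technique, not necessarily what `ESMEmbedder` does internally) is to symmetrize each head's `L × L` attention matrix and stack the per-head values as pairwise features, which a learned linear layer would then project down to `d_pair` channels:

```python
def attention_pair_features(attn: list[list[list[float]]]) -> list[list[list[float]]]:
    """Symmetrize per-head [H][L][L] attention into [L][L][H] pair features.

    Attention is direction-dependent (a[i][j] != a[j][i] in general);
    averaging the two directions yields a symmetric feature per head.
    """
    H, L = len(attn), len(attn[0])
    return [
        [[0.5 * (attn[h][i][j] + attn[h][j][i]) for h in range(H)]
         for j in range(L)]
        for i in range(L)
    ]

# Two heads over three residues (values are arbitrary illustrations).
attn = [
    [[0.6, 0.3, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]],
    [[0.5, 0.4, 0.1], [0.3, 0.5, 0.2], [0.1, 0.1, 0.8]],
]
feats = attention_pair_features(attn)
```

Symmetrizing is what makes the result usable as a pair representation: `feats[i][j]` and `feats[j][i]` describe the same residue pair.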

Using ESM (Evolutionary Scale Modeling) pretrained protein language model representations as initial features provides richer single representations than one-hot inputs.

```python
embedder = EMBEDDER_REGISTRY.build(
    "esm",
    d_single=256,
    d_pair=128,
    esm_model="esm2_t33_650M_UR50D",
    freeze_lm=True,
)

output = embedder(aatype, residue_index)
```

**Note:** The ESM embedder requires the `fair-esm` package. Install it with `pip install fair-esm`.