Embedders¶
Embedders convert raw amino acid sequences (and optional features like MSAs)
into the initial single and pair representations consumed by trunk blocks.
All implementations subclass BaseEmbedder and are registered in
EMBEDDER_REGISTRY.
Quick Start¶
```python
from molfun.modules.embedders import EMBEDDER_REGISTRY

# List available embedders
print(EMBEDDER_REGISTRY.list())
# ["esm", "input"]

# Build an embedder and run it on int64 [B, L] tensors
embedder = EMBEDDER_REGISTRY.build("input", c_s=256, c_z=128)
output = embedder(aatype, residue_index)

# Swap the embedder in a live model
from molfun import MolfunStructureModel

model = MolfunStructureModel.from_pretrained("openfold_v2")
model.swap("embedder", "esm")
```
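EMBEDDER_REGISTRY follows a common name → class registry pattern. The sketch below is a minimal, self-contained illustration of that pattern, not molfun's actual implementation; the `Registry` and `InputEmbedderStub` names here are hypothetical stand-ins.

```python
# Minimal sketch of a name -> class registry like EMBEDDER_REGISTRY.
# Illustrative only; molfun's real registry may differ in detail.

class Registry:
    def __init__(self):
        self._classes = {}

    def register(self, name):
        """Decorator that records a class under `name`."""
        def wrap(cls):
            self._classes[name] = cls
            return cls
        return wrap

    def list(self):
        return sorted(self._classes)

    def build(self, name, **kwargs):
        """Instantiate the registered class with the given config."""
        return self._classes[name](**kwargs)


REGISTRY = Registry()


@REGISTRY.register("input")
class InputEmbedderStub:
    """Hypothetical stand-in for a registered embedder class."""
    def __init__(self, c_s, c_z):
        self.c_s, self.c_z = c_s, c_z


print(REGISTRY.list())  # ['input']
emb = REGISTRY.build("input", c_s=256, c_z=128)
```

The decorator-based registration is what lets `model.swap("embedder", "esm")` rebuild a component from a string name plus keyword config.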
EMBEDDER_REGISTRY¶
BaseEmbedder¶
BaseEmbedder ¶
Bases: ABC, Module

Converts raw features → initial representations for trunk blocks.

Different paradigms:

- AF2 InputEmbedder: aatype + relpos → single; aatype outer → pair; MSA feat → msa
- ESM Embedder: frozen LM → single repr; optional pair from attention maps
- Sequence Embedder: learned embedding table (baseline)
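Whatever the paradigm, every subclass implements the same contract: an abstract `forward` plus a `d_pair` property. The stand-in below illustrates that contract in plain Python; `MiniBaseEmbedder` and `TableEmbedder` are hypothetical names, and the real BaseEmbedder also inherits `torch.nn.Module` and works with tensors rather than lists.

```python
from abc import ABC, abstractmethod

# Stripped-down stand-in for the BaseEmbedder contract (the real class
# also inherits torch.nn.Module and returns EmbedderOutput tensors).

class MiniBaseEmbedder(ABC):
    @property
    @abstractmethod
    def d_pair(self) -> int:
        """Output pair dimension (0 if no pair track)."""

    @abstractmethod
    def forward(self, aatype, residue_index, **kwargs):
        """Map raw features to initial representations."""


class TableEmbedder(MiniBaseEmbedder):
    """Hypothetical baseline: per-residue lookup table, no pair track."""

    def __init__(self, c_s: int, num_amino_acids: int = 21):
        # One vector per residue type (a learned table in a real model).
        self.table = {a: [float(a)] * c_s for a in range(num_amino_acids)}

    @property
    def d_pair(self) -> int:
        return 0  # sequence-only embedder: no pair representation

    def forward(self, aatype, residue_index, **kwargs):
        # [L] residue types -> [L, c_s] single representation
        return [self.table[a] for a in aatype]


emb = TableEmbedder(c_s=4)
s = emb.forward([0, 5, 20], residue_index=[0, 1, 2])
```

Returning `d_pair = 0` is how a sequence-only embedder signals that downstream trunk blocks should not expect a pair track.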
d_pair abstractmethod property ¶

Output pair representation dimension (0 if no pair track).
forward abstractmethod ¶

```python
forward(
    aatype: Tensor,
    residue_index: Tensor,
    msa: Tensor | None = None,
    msa_mask: Tensor | None = None,
    **kwargs,
) -> EmbedderOutput
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| aatype | Tensor | Residue types [B, L] (int64, 0-20). | required |
| residue_index | Tensor | Position indices [B, L]. | required |
| msa | Tensor \| None | MSA features [B, N, L, D_msa_feat] (optional). | None |
| msa_mask | Tensor \| None | MSA mask [B, N, L] (optional). | None |

Returns:

| Type | Description |
|---|---|
| EmbedderOutput | EmbedderOutput with initial single and pair representations. |
Abstract base class for all embedders.
Forward Signature¶
| Parameter | Type | Description |
|---|---|---|
| aatype | Tensor | Amino acid type indices (B, L) |
| residue_index | Tensor | Residue position indices (B, L) |
| `**kwargs` | dict | Additional features (e.g., MSA, templates) |
Returns: EmbedderOutput
EmbedderOutput¶
EmbedderOutput dataclass ¶

Standardized dataclass holding the outputs of any embedder.

| Field | Type | Description |
|---|---|---|
| s | Tensor | Single representation (B, L, c_s) |
| z | Tensor | Pair representation (B, L, L, c_z) |
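A plausible shape for such a container is sketched below. This is an illustration only; field defaults and any extra fields (e.g., an MSA track) in molfun's actual EmbedderOutput may differ.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EmbedderOutputSketch:
    """Sketch of a standardized embedder output container.

    `s` is the single representation (B, L, c_s); `z` is the pair
    representation (B, L, L, c_z). `Any` stands in for torch.Tensor
    so the sketch stays dependency-free.
    """
    s: Any
    z: Optional[Any] = None  # None when the embedder has no pair track


# A sequence-only embedder would populate `s` and leave `z` unset.
out = EmbedderOutputSketch(s=[[0.0, 1.0]])
```

Standardizing on one output type is what makes embedders swappable: trunk blocks read `output.s` and `output.z` without caring which embedder produced them.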
InputEmbedder¶
InputEmbedder ¶
Bases: BaseEmbedder

AlphaFold2-style input embedding.

Creates:

- single repr from aatype one-hot
- pair repr from outer product of aatype + relative position
- MSA repr from MSA features (or broadcasts single if no MSA)

Standard input embedder that computes single and pair representations from amino acid types and relative positional encodings.

```python
embedder = EMBEDDER_REGISTRY.build(
    "input",
    c_s=256,
    c_z=128,
    max_relative_position=32,
)
output = embedder(aatype, residue_index)
s = output.s  # (B, L, 256)
z = output.z  # (B, L, L, 128)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| c_s | int | required | Single representation dimension |
| c_z | int | required | Pair representation dimension |
| max_relative_position | int | 32 | Maximum relative positional encoding distance |
| num_amino_acids | int | 21 | Number of amino acid types (including unknown) |
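The role of max_relative_position in AF2-style relative positional encoding is to clip pairwise sequence offsets into 2 × max + 1 buckets, each of which indexes a learned pair embedding. A self-contained sketch with a small max (the exact bucketing in molfun may differ):

```python
def relpos_bucket(residue_index, max_relative_position=32):
    """Clip pairwise offsets j - i to [-max, max], shift to [0, 2*max].

    In the real InputEmbedder each bucket id would index a learned
    embedding of size c_z; here we just return the L x L bucket ids.
    """
    m = max_relative_position
    return [
        [min(max(j - i, -m), m) + m for j in residue_index]
        for i in residue_index
    ]


buckets = relpos_bucket([0, 1, 2, 50], max_relative_position=4)
# Offsets beyond +/-4 all collapse into the edge buckets (0 or 8),
# so residues farther apart than max look identically "far".
```

Clipping is what keeps the pair-embedding table small and length-independent: any separation beyond max_relative_position shares a single "distant" bucket.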
ESMEmbedder¶
ESMEmbedder ¶
Bases: BaseEmbedder
Frozen ESM-2 language model as input embedder.
Produces single representations from ESM hidden states and optionally pair representations from attention maps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| esm_model | str | ESM model name (e.g. "esm2_t33_650M_UR50D"). | 'esm2_t33_650M_UR50D' |
| d_single | int | Output single dimension (projects from ESM hidden dim). | 384 |
| d_pair | int | Output pair dimension (0 to disable pair extraction). | 128 |
| freeze_lm | bool | Whether to freeze the language model weights. | True |
| layer_idx | int | Which ESM layer to extract representations from (-1 = last). | -1 |
Embedder that uses ESM (Evolutionary Scale Modeling) language model representations as initial features, providing richer single representations drawn from a pretrained protein language model.

```python
embedder = EMBEDDER_REGISTRY.build(
    "esm",
    c_s=256,
    c_z=128,
    esm_model="esm2_t33_650M_UR50D",
    freeze_esm=True,
)
output = embedder(aatype, residue_index)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| c_s | int | required | Single representation dimension |
| c_z | int | required | Pair representation dimension |
| esm_model | str | "esm2_t33_650M_UR50D" | ESM model name |
| freeze_esm | bool | True | Whether to freeze ESM weights |
| layer_idx | int | -1 | Which ESM layer to extract representations from |
Note

The ESM embedder requires the fair-esm package. Install it with
`pip install fair-esm`.
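One common way to turn language-model attention maps into a pair track, as the ESMEmbedder description suggests, is to symmetrize per-head attention weights before projecting them to c_z. A self-contained sketch (the actual molfun extraction and projection are likely learned and operate on stacked heads):

```python
def symmetrize_attention(attn):
    """Average attn[i][j] with attn[j][i] to get symmetric pair features.

    `attn` is one L x L attention map from a single LM head; stacking
    several heads/layers along a feature axis and projecting to c_z
    would yield the pair representation.
    """
    L = len(attn)
    return [
        [0.5 * (attn[i][j] + attn[j][i]) for j in range(L)]
        for i in range(L)
    ]


attn = [[1.0, 0.2, 0.0],
        [0.6, 1.0, 0.4],
        [0.0, 0.0, 1.0]]
pair = symmetrize_attention(attn)
# pair[0][1] == pair[1][0]: the pair track no longer depends on
# attention direction, matching the symmetry of residue-pair features.
```

Symmetrization is the natural choice here because residue-pair features (like contacts) are order-independent, while raw attention is not.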