
Trunk Blocks

Trunk blocks form the core of the model architecture. Each block processes single and pair representations through attention and feed-forward layers. All implementations subclass BaseBlock and are registered in BLOCK_REGISTRY.

Quick Start

from molfun.modules.blocks import BLOCK_REGISTRY

# List available blocks
print(BLOCK_REGISTRY.list())
# ["evoformer", "pairformer", "simple_transformer"]

# Build a block
block = BLOCK_REGISTRY.build("pairformer", c_s=256, c_z=128, num_heads=8)

# Swap blocks in a live model
from molfun import MolfunStructureModel
model = MolfunStructureModel.from_pretrained("openfold_v2")
model.swap_all("block", "pairformer")

BLOCK_REGISTRY

BLOCK_REGISTRY = ModuleRegistry('block')

The module-level registry for trunk blocks. Use BLOCK_REGISTRY.list() to enumerate registered blocks and BLOCK_REGISTRY.build(name, ...) to instantiate one by name.
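The registry is a standard name-to-class mapping. A minimal sketch of how such a registry could work, assuming a `register` decorator (the real ModuleRegistry API may differ; only `.list()` and `.build()` are shown on this page):

```python
# Minimal registry sketch; mirrors the BLOCK_REGISTRY.list()/.build() usage
# shown above. The real molfun ModuleRegistry may differ in details.
class ModuleRegistry:
    def __init__(self, kind: str):
        self.kind = kind
        self._classes: dict = {}

    def register(self, name: str):
        """Decorator that maps a name to a class (assumed API)."""
        def wrap(cls):
            self._classes[name] = cls
            return cls
        return wrap

    def list(self):
        """Names of all registered classes, sorted."""
        return sorted(self._classes)

    def build(self, name: str, **kwargs):
        """Look up a class by name and instantiate it with kwargs."""
        if name not in self._classes:
            raise KeyError(f"unknown {self.kind}: {name!r}")
        return self._classes[name](**kwargs)


BLOCK_REGISTRY = ModuleRegistry("block")

@BLOCK_REGISTRY.register("pairformer")
class PairformerBlock:
    """Toy stand-in for the real block, used to show registration."""
    def __init__(self, c_s: int, c_z: int, num_heads: int = 8):
        self.c_s, self.c_z, self.num_heads = c_s, c_z, num_heads
```

Building by name plus `**kwargs` is what lets `swap_all` replace blocks without the caller knowing the concrete class.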

BaseBlock

Bases: ABC, Module

A single repeating block of the trunk.

Subclasses implement the specific architecture (Evoformer, Pairformer, simple transformer, etc.) but all accept and return representations through the same BlockOutput interface.

d_single abstractmethod property

d_single: int

Single/MSA representation dimension.

d_pair abstractmethod property

d_pair: int

Pair representation dimension (0 if single-track only).

forward abstractmethod

forward(single: Tensor, pair: Tensor | None = None, mask: Tensor | None = None, pair_mask: Tensor | None = None) -> BlockOutput

Process representations through one block.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| single | Tensor | Per-token features. Evoformer: MSA representation [B, N_msa, L, D_msa]; Pairformer/Simple: single representation [B, L, D_single]. | required |
| pair | Tensor \| None | Pairwise features [B, L, L, D_pair]. None for single-track models. | None |
| mask | Tensor \| None | Token mask [B, L] or [B, N, L]. | None |
| pair_mask | Tensor \| None | Pair mask [B, L, L]. | None |

Returns:

| Type | Description |
| --- | --- |
| BlockOutput | BlockOutput with updated single and pair representations. |

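The contract a subclass must satisfy can be sketched without any framework: an abstract base with the two dimension properties and a forward that returns a BlockOutput. The `BlockOutput` stand-in and the `IdentityBlock` toy subclass below are illustrations, not molfun code; the real classes operate on torch Tensors:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

@dataclass
class BlockOutput:            # stand-in for molfun's BlockOutput
    s: list                   # updated single representation
    z: Optional[list] = None  # updated pair representation (None if single-track)

class BaseBlock(ABC):
    @property
    @abstractmethod
    def d_single(self) -> int:
        """Single/MSA representation dimension."""

    @property
    @abstractmethod
    def d_pair(self) -> int:
        """Pair representation dimension (0 if single-track only)."""

    @abstractmethod
    def forward(self, single, pair=None, mask=None, pair_mask=None) -> BlockOutput:
        """Process representations through one block."""

class IdentityBlock(BaseBlock):
    """Toy single-track subclass: returns its inputs unchanged."""
    def __init__(self, d: int):
        self._d = d

    @property
    def d_single(self) -> int:
        return self._d

    @property
    def d_pair(self) -> int:
        return 0  # single-track: no pair representation

    def forward(self, single, pair=None, mask=None, pair_mask=None) -> BlockOutput:
        return BlockOutput(s=single, z=pair)
```

Because every block accepts and returns the same interface, a trunk can mix block types or have them swapped at runtime, as `swap_all` does in the Quick Start.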

BlockOutput

BlockOutput dataclass

Standardized output from a trunk block.

| Field | Type | Description |
| --- | --- | --- |
| s | Tensor | Updated single representation |
| z | Tensor | Updated pair representation |
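As a dataclass, the output is accessed by attribute. A minimal stand-in with the same field names (the real class holds torch Tensors):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class BlockOutput:
    s: Any  # updated single representation
    z: Any  # updated pair representation

# Attribute access, as used when threading representations between blocks.
out = BlockOutput(s=[0.1, 0.2], z=[[0.0]])
s, z = out.s, out.z
```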

Pairformer

PairformerBlock

Bases: BaseBlock

Pairformer block with triangular multiplicative updates and pluggable attention.

Unlike EvoformerBlock, it operates on a single representation [B, L, D] rather than an MSA representation [B, N, L, D]. This is the AF3/Protenix paradigm, in which MSA information is compressed into a single representation before the trunk.

block = BLOCK_REGISTRY.build(
    "pairformer",
    c_s=256,
    c_z=128,
    num_heads=8,
    dropout=0.0,
)
output = block(s, z, mask=mask)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| c_s | int | required | Single representation dimension |
| c_z | int | required | Pair representation dimension |
| num_heads | int | 8 | Number of attention heads |
| dropout | float | 0.0 | Dropout probability |
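A trunk is typically a stack of such blocks, each consuming the previous block's output. With stand-in callables in place of real PairformerBlock instances (which would return a BlockOutput), the threading pattern is:

```python
# Stand-in blocks: each maps (s, z) -> (s, z). In molfun these would be
# trunk blocks built via BLOCK_REGISTRY.build(...).
def make_block(scale: float):
    def block(s, z):
        return [x * scale for x in s], z
    return block

def run_trunk(blocks, s, z):
    """Thread the single/pair representations through each block in order."""
    for block in blocks:
        s, z = block(s, z)
    return s, z

blocks = [make_block(2.0) for _ in range(3)]
s_out, z_out = run_trunk(blocks, [1.0, 1.0], None)
```

The shared interface is what makes the loop this simple: every block, regardless of architecture, accepts and returns the same pair of representations.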

Evoformer

EvoformerBlock

Bases: BaseBlock

The Evoformer block from AlphaFold2, with MSA row/column attention and an outer product mean for pair updates.

The attention mechanism used for MSA row/column attention can be swapped by passing attention_cls (e.g. FlashAttention, GatedAttention).

block = BLOCK_REGISTRY.build(
    "evoformer",
    c_s=256,
    c_z=128,
    c_m=256,
    num_heads=8,
)
output = block(s, z, mask=mask)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| c_s | int | required | Single representation dimension |
| c_z | int | required | Pair representation dimension |
| c_m | int | 256 | MSA representation dimension |
| num_heads | int | 8 | Number of attention heads |
| dropout | float | 0.0 | Dropout probability |
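The outer product mean mentioned above is how the Evoformer turns MSA information into a pair update: for each residue pair (i, j), it averages the outer product of the i-th and j-th column features over the MSA rows. A pure-Python sketch on nested lists, omitting the learned projections the real block applies before and after:

```python
def outer_product_mean(msa):
    """msa: [N_msa][L][D] nested lists -> pair update [L][L][D*D].

    For each residue pair (i, j), averages the outer product of the
    i-th and j-th column features over the N_msa MSA rows.
    """
    n, L, d = len(msa), len(msa[0]), len(msa[0][0])
    pair = [[[0.0] * (d * d) for _ in range(L)] for _ in range(L)]
    for row in msa:
        for i in range(L):
            for j in range(L):
                cell = pair[i][j]
                for a in range(d):
                    for b in range(d):
                        # accumulate the row-averaged outer product
                        cell[a * d + b] += row[i][a] * row[j][b] / n
    return pair
```

This is the bridge between the MSA track [B, N_msa, L, D] and the pair track [B, L, L, D_pair] described in the forward signature.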

SimpleTransformer

SimpleTransformerBlock

Bases: BaseBlock

Single-track transformer block with standard self-attention, an FFN, and pluggable attention.

There is no pair representation: the block is designed for ESMFold-style models, as a lightweight baseline for ablation studies, and for property prediction heads.

block = BLOCK_REGISTRY.build(
    "simple_transformer",
    c_s=256,
    num_heads=8,
    ffn_dim=1024,
)
output = block(s, mask=mask)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| c_s | int | required | Input/output dimension |
| num_heads | int | 8 | Number of attention heads |
| ffn_dim | int \| None | None | FFN hidden dimension (defaults to 4 * c_s) |
| dropout | float | 0.0 | Dropout probability |
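The ffn_dim default follows the usual transformer convention of a 4x expansion over the model dimension. A sketch of the resolution logic (`resolve_ffn_dim` is a hypothetical helper, not a molfun function; the constructor presumably does the equivalent inline):

```python
def resolve_ffn_dim(c_s, ffn_dim=None):
    # Hypothetical helper: None means "use the 4 * c_s default" per the
    # parameter table above; an explicit value is passed through.
    return ffn_dim if ffn_dim is not None else 4 * c_s
```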