Attention Modules

Pluggable attention implementations registered in the ATTENTION_REGISTRY. All implementations subclass BaseAttention and share a common forward signature.

Quick Start

from molfun.modules.attention import ATTENTION_REGISTRY

# List available attention mechanisms
print(ATTENTION_REGISTRY.list())
# ["flash", "gated", "linear", "standard"]

# Build an attention module
attn = ATTENTION_REGISTRY.build("flash", num_heads=8, head_dim=64)

# Use in a model via swap
from molfun import MolfunStructureModel
model = MolfunStructureModel.from_pretrained("openfold_v2")
model.swap("attention", "flash")

ATTENTION_REGISTRY

ATTENTION_REGISTRY module-attribute

ATTENTION_REGISTRY = ModuleRegistry('attention')

BaseAttention

BaseAttention

Bases: ABC, Module

Any attention mechanism that maps (Q, K, V) → output.

Implementations must at minimum support:

- Multi-head attention with num_heads heads of head_dim each.
- An optional additive bias tensor (used by Evoformer pair bias).
- An optional boolean mask (True = attend, False = ignore).

The input tensors already have the head dimension split out:

q: [B, H, Lq, D]
k: [B, H, Lk, D]
v: [B, H, Lk, D]

Return shape: [B, H, Lq, D].
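The forward contract above can be sketched as a plain-PyTorch reference (illustrative only, not molfun's implementation):

```python
import torch

def reference_attention(q, k, v, mask=None, bias=None):
    """Scaled dot-product attention on pre-split heads: q [B, H, Lq, D], k/v [B, H, Lk, D]."""
    d = q.shape[-1]
    logits = q @ k.transpose(-2, -1) / d**0.5     # [B, H, Lq, Lk]
    if bias is not None:
        logits = logits + bias                     # additive pair bias
    if mask is not None:
        logits = logits.masked_fill(~mask, float("-inf"))  # True = attend
    weights = logits.softmax(dim=-1)
    return weights @ v                             # [B, H, Lq, D]
```

Note that the mask broadcasts over the head axis when given as [B, 1, Lq, Lk], matching the signature documented below.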

forward abstractmethod

forward(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, bias: Tensor | None = None) -> torch.Tensor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query [B, H, Lq, D] | required |
| k | Tensor | Key [B, H, Lk, D] | required |
| v | Tensor | Value [B, H, Lk, D] | required |
| mask | Tensor \| None | Boolean mask [B, 1\|H, Lq, Lk]. True = attend. | None |
| bias | Tensor \| None | Additive bias [B, 1\|H, Lq, Lk] added to logits. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attended output [B, H, Lq, D]. |

from_config classmethod

from_config(cfg: AttentionConfig, **overrides) -> BaseAttention

Build from a dataclass config, with optional field overrides.
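The override semantics can be sketched with a stand-in dataclass (the real AttentionConfig fields may differ; num_heads, head_dim, and dropout are assumptions here):

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for molfun's AttentionConfig; field names are assumed.
@dataclass
class AttentionConfig:
    num_heads: int = 8
    head_dim: int = 64
    dropout: float = 0.0

def from_config_like(cfg: AttentionConfig, **overrides) -> AttentionConfig:
    # from_config merges keyword overrides into the config before building.
    return replace(cfg, **overrides)
```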

Abstract base class for all attention modules.

from molfun.modules.attention.base import BaseAttention

class MyAttention(BaseAttention):
    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        # Backing fields for the abstract properties below.
        self._num_heads = num_heads
        self._head_dim = head_dim

    def forward(self, q, k, v, mask=None, bias=None):
        ...

    @property
    def num_heads(self) -> int:
        return self._num_heads

    @property
    def head_dim(self) -> int:
        return self._head_dim

Forward Signature

| Parameter | Type | Description |
| --- | --- | --- |
| q | Tensor | Query tensor (B, H, L, D) |
| k | Tensor | Key tensor (B, H, L, D) |
| v | Tensor | Value tensor (B, H, L, D) |
| mask | Tensor \| None | Attention mask (B, 1, L, L) or (B, H, L, L) |
| bias | Tensor \| None | Attention bias (e.g., pair representation) |

Returns: Tensor of shape (B, H, L, D).


FlashAttention

FlashAttention

Bases: BaseAttention

Flash / memory-efficient attention via F.scaled_dot_product_attention.

Supports the same interface as StandardAttention so it can be swapped in anywhere via the registry or ModuleSwapper.

from_standard classmethod

from_standard(standard_attn: BaseAttention) -> FlashAttention

Convert any BaseAttention to FlashAttention, preserving dimensions.

GPU-optimized attention using the Flash Attention 2 algorithm. The flash backend of F.scaled_dot_product_attention requires a CUDA device; on other devices PyTorch falls back to its memory-efficient or math backends, so the module still runs everywhere.

from molfun.modules.attention import ATTENTION_REGISTRY

attn = ATTENTION_REGISTRY.build("flash", num_heads=8, head_dim=64)
output = attn(q, k, v, mask=mask)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_heads | int | required | Number of attention heads |
| head_dim | int | required | Dimension per head |
| dropout | float | 0.0 | Attention dropout (training only) |
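A quick pure-torch check (not molfun code) that the fused kernel FlashAttention wraps agrees with naive softmax attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Fused kernel; PyTorch picks the flash / memory-efficient / math backend.
fused = F.scaled_dot_product_attention(q, k, v)

# Naive reference for comparison.
logits = q @ k.transpose(-2, -1) / 64**0.5
naive = logits.softmax(dim=-1) @ v
```

The two results match to floating-point tolerance; the fused path simply avoids materializing the [Lq, Lk] attention matrix.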

StandardAttention

StandardAttention

Bases: BaseAttention

Vanilla scaled dot-product multi-head attention.

Operates on pre-split heads: inputs are [B, H, L, D].

Standard scaled dot-product attention. Works on all devices.

attn = ATTENTION_REGISTRY.build("standard", num_heads=8, head_dim=64)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_heads | int | required | Number of attention heads |
| head_dim | int | required | Dimension per head |
| dropout | float | 0.0 | Attention dropout |

LinearAttention

LinearAttention

Bases: BaseAttention

Linear attention with ELU+1 kernel feature map.

Computes attention in O(L·D²) instead of O(L²·D), enabling sub-quadratic scaling on very long sequences.

Linear-complexity attention using kernel feature maps. Suitable for very long sequences.

attn = ATTENTION_REGISTRY.build("linear", num_heads=8, head_dim=64)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_heads | int | required | Number of attention heads |
| head_dim | int | required | Dimension per head |
| feature_map | str | "elu" | Kernel feature map ("elu", "relu", "favor+") |
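The ELU+1 factorization can be sketched in a few lines of torch (an illustrative reference, not molfun's implementation): because the feature map keeps all values positive, the softmax numerator and denominator factor into per-key sums, giving O(L·D²) work.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Linear attention with the ELU+1 feature map; q, k, v are [B, H, L, D]."""
    q = F.elu(q) + 1  # positive feature map phi(q)
    k = F.elu(k) + 1  # positive feature map phi(k)
    # Aggregate keys and values once: kv is [B, H, D, D].
    kv = torch.einsum("bhld,bhlm->bhdm", k, v)
    # Normalizer: phi(q) dotted with the summed keys, [B, H, L].
    z = 1.0 / (torch.einsum("bhld,bhd->bhl", q, k.sum(dim=2)) + eps)
    return torch.einsum("bhld,bhdm,bhl->bhlm", q, kv, z)
```

Expanding the einsums recovers exactly the quadratic form softmax-free attention, weights_ij = phi(q_i)·phi(k_j) / Σ_j phi(q_i)·phi(k_j), so the two orderings agree to numerical tolerance.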

GatedAttention

GatedAttention

Bases: BaseAttention

Gated multi-head attention: softmax attention with a learned sigmoid gate.

Used in AlphaFold2's Evoformer (gated self-attention with pair bias).

Attention with gating mechanism as used in AlphaFold2 / OpenFold.

attn = ATTENTION_REGISTRY.build("gated", num_heads=8, head_dim=64)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_heads | int | required | Number of attention heads |
| head_dim | int | required | Dimension per head |
| dropout | float | 0.0 | Attention dropout |

The output is element-wise multiplied by a learned gating vector before the final linear projection.
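The gating step can be sketched as follows (a minimal illustration of the pattern, not molfun's GatedAttention; the single-Linear gate is an assumption):

```python
import torch

class GatedAttentionSketch(torch.nn.Module):
    """Softmax attention whose merged-head output is scaled by a learned sigmoid gate."""

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        d_model = num_heads * head_dim
        # Hypothetical gate projection; real implementations often gate from the input.
        self.gate = torch.nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        B, H, L, D = q.shape
        attn = torch.softmax(q @ k.transpose(-2, -1) / D**0.5, dim=-1) @ v
        x = attn.transpose(1, 2).reshape(B, L, H * D)    # merge heads
        g = torch.sigmoid(self.gate(x))                  # learned sigmoid gate in (0, 1)
        return (g * x).view(B, L, H, D).transpose(1, 2)  # back to [B, H, L, D]
```

The sigmoid keeps each gate value in (0, 1), letting the model attenuate individual channels of the attention output before the final projection.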