# Attention Modules

Pluggable attention implementations registered in the `ATTENTION_REGISTRY`.
All implementations subclass `BaseAttention` and share a common forward
signature.
## Quick Start

```python
from molfun.modules.attention import ATTENTION_REGISTRY

# List available attention mechanisms
print(ATTENTION_REGISTRY.list())
# ["flash", "gated", "linear", "standard"]

# Build an attention module
attn = ATTENTION_REGISTRY.build("flash", num_heads=8, head_dim=64)

# Use in a model via swap
from molfun import MolfunStructureModel
model = MolfunStructureModel.from_pretrained("openfold_v2")
model.swap("attention", "flash")
```
## ATTENTION_REGISTRY
## BaseAttention

Bases: `ABC`, `Module`

Any attention mechanism that maps (Q, K, V) → output.

Implementations must at minimum support:

- Multi-head attention with `num_heads` heads of `head_dim` each.
- An optional additive bias tensor (used by Evoformer pair bias).
- An optional boolean mask (True = attend, False = ignore).
The input tensors already have the head dimension split out:

```text
q: [B, H, Lq, D]
k: [B, H, Lk, D]
v: [B, H, Lk, D]
```

Return shape: `[B, H, Lq, D]`.
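For illustration, producing this pre-split layout from a packed `[B, L, H*D]` projection is a reshape plus transpose. The sketch below uses numpy as a stand-in for torch; the variable names are not from the library:

```python
import numpy as np

B, L, H, D = 2, 5, 8, 64  # batch, sequence length, heads, head dim

# Packed projection output: [B, L, H*D]
x = np.random.randn(B, L, H * D)

# Split out the head dimension: [B, L, H, D] -> [B, H, L, D]
q = x.reshape(B, L, H, D).transpose(0, 2, 1, 3)

print(q.shape)  # (2, 8, 5, 64)
```

The inverse (transpose back, then merge heads) recovers the packed tensor exactly.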
### forward (abstractmethod)

```python
forward(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, bias: Tensor | None = None) -> torch.Tensor
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `Tensor` | Query `[B, H, Lq, D]` | *required* |
| `k` | `Tensor` | Key `[B, H, Lk, D]` | *required* |
| `v` | `Tensor` | Value `[B, H, Lk, D]` | *required* |
| `mask` | `Tensor \| None` | Boolean mask `[B, 1\|H, Lq, Lk]`. True = attend. | `None` |
| `bias` | `Tensor \| None` | Additive bias `[B, 1\|H, Lq, Lk]` added to logits. | `None` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Attended output `[B, H, Lq, D]`. |
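As a sketch of the math this contract implies (not the library's implementation), scaled dot-product attention with the optional mask and bias can be written in a few lines; numpy is used here so the example stands alone:

```python
import numpy as np

def reference_attention(q, k, v, mask=None, bias=None):
    """Scaled dot-product attention over pre-split heads [B, H, L, D]."""
    d = q.shape[-1]
    # Attention logits: [B, H, Lq, Lk]
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)
    if bias is not None:
        logits = logits + bias                  # additive bias (e.g. pair bias)
    if mask is not None:
        logits = np.where(mask, logits, -1e9)   # True = attend, False = ignore
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v                          # [B, H, Lq, D]

q = np.random.randn(1, 2, 4, 8)
k = np.random.randn(1, 2, 6, 8)
v = np.random.randn(1, 2, 6, 8)
out = reference_attention(q, k, v)
print(out.shape)  # (1, 2, 4, 8)
```

Note that a `[B, 1, Lq, Lk]` mask broadcasts across the head dimension, matching the `1|H` shapes in the table above.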
### from_config (classmethod)

Build from a dataclass config, with optional field overrides.
Abstract base class for all attention modules. A minimal subclass:

```python
from molfun.modules.attention.base import BaseAttention

class MyAttention(BaseAttention):
    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self._num_heads = num_heads
        self._head_dim = head_dim

    def forward(self, q, k, v, mask=None, bias=None):
        ...

    @property
    def num_heads(self) -> int:
        return self._num_heads

    @property
    def head_dim(self) -> int:
        return self._head_dim
```
### Forward Signature

| Parameter | Type | Description |
|---|---|---|
| `q` | `Tensor` | Query tensor `(B, H, L, D)` |
| `k` | `Tensor` | Key tensor `(B, H, L, D)` |
| `v` | `Tensor` | Value tensor `(B, H, L, D)` |
| `mask` | `Tensor \| None` | Attention mask `(B, 1, L, L)` or `(B, H, L, L)` |
| `bias` | `Tensor \| None` | Attention bias (e.g., pair representation) |

Returns: `Tensor` of shape `(B, H, L, D)`.
## FlashAttention

Bases: `BaseAttention`

Flash / memory-efficient attention via `F.scaled_dot_product_attention`.

Supports the same interface as `StandardAttention`, so it can be swapped in anywhere via the registry or `ModuleSwapper`.
### from_standard (classmethod)

Convert any `BaseAttention` to `FlashAttention`, preserving dimensions.

GPU-optimized attention using Flash Attention 2. Requires a CUDA device
and the flash-attn package.

```python
from molfun.modules.attention import ATTENTION_REGISTRY

attn = ATTENTION_REGISTRY.build("flash", num_heads=8, head_dim=64)
output = attn(q, k, v, mask=mask)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_heads` | `int` | *required* | Number of attention heads |
| `head_dim` | `int` | *required* | Dimension per head |
| `dropout` | `float` | `0.0` | Attention dropout (training only) |
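The fused kernel computes the same result as explicit softmax attention, only with better memory behavior. A quick equivalence check in plain torch (independent of molfun):

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 8, 16, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Fused / memory-efficient path
fused = F.scaled_dot_product_attention(q, k, v)

# Explicit reference path: softmax(QK^T / sqrt(D)) V
logits = q @ k.transpose(-2, -1) / D**0.5
manual = torch.softmax(logits, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-4))
```

On CPU this dispatches to a math fallback; the CUDA flash kernel is selected automatically when available.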
## StandardAttention

Bases: `BaseAttention`

Vanilla scaled dot-product multi-head attention. Operates on pre-split heads (inputs are `[B, H, L, D]`) and works on all devices.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_heads` | `int` | *required* | Number of attention heads |
| `head_dim` | `int` | *required* | Dimension per head |
| `dropout` | `float` | `0.0` | Attention dropout |
## LinearAttention

Bases: `BaseAttention`

Linear attention with an ELU+1 kernel feature map. Computes attention in O(L·D²) instead of O(L²·D), enabling sub-quadratic scaling on very long sequences.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_heads` | `int` | *required* | Number of attention heads |
| `head_dim` | `int` | *required* | Dimension per head |
| `feature_map` | `str` | `"elu"` | Kernel feature map (`"elu"`, `"relu"`, `"favor+"`) |
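The O(L·D²) claim comes from reassociating the attention product: with a positive feature map φ, `softmax(QKᵀ)V` is replaced by `φ(Q)(φ(K)ᵀV)`, so the L×L weight matrix is never materialized. A sketch of the math (illustrative, not the library's implementation) with φ(x) = ELU(x) + 1:

```python
import numpy as np

def elu_plus_one(x):
    # ELU(x) + 1: strictly positive feature map, so attention weights stay >= 0
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(q, k, v):
    """Kernelized attention over [B, H, L, D] in O(L * D^2)."""
    q, k = elu_plus_one(q), elu_plus_one(k)
    # D x D summary of all keys/values -- this is where the L^2 term disappears
    kv = np.einsum("bhld,bhle->bhde", k, v)
    # Per-query normalizer: phi(q_i) . sum_j phi(k_j)
    z = 1.0 / np.einsum("bhld,bhd->bhl", q, k.sum(axis=2))
    return np.einsum("bhld,bhde,bhl->bhle", q, kv, z)

q = np.random.randn(1, 2, 32, 8)
k = np.random.randn(1, 2, 32, 8)
v = np.random.randn(1, 2, 32, 8)
out = linear_attention(q, k, v)
print(out.shape)  # (1, 2, 32, 8)
```

Because the reassociation is exact, this yields identical output to forming the full φ(Q)φ(K)ᵀ weight matrix and normalizing its rows; only softmax attention itself is approximated.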
## GatedAttention

Bases: `BaseAttention`

Gated multi-head attention: softmax attention with a learned sigmoid gate, as used in AlphaFold2 / OpenFold (gated self-attention with pair bias in the Evoformer).
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_heads` | `int` | *required* | Number of attention heads |
| `head_dim` | `int` | *required* | Dimension per head |
| `dropout` | `float` | `0.0` | Attention dropout |

The output is element-wise multiplied by a learned gating vector before the final linear projection.
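The gating step can be sketched as follows. This is illustrative only: deriving the gate from the query-side input via a learned projection is an assumption based on the AlphaFold2 design, not molfun's confirmed internals:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

B, H, L, D = 1, 2, 4, 8
rng = np.random.default_rng(0)
attn_out = rng.standard_normal((B, H, L, D))  # output of softmax attention
x = rng.standard_normal((B, H, L, D))         # query-side input (assumed gate source)
w_gate = rng.standard_normal((D, D))          # learned gate projection (assumed shape)

gate = sigmoid(x @ w_gate)   # values in (0, 1), same shape as attn_out
gated = gate * attn_out      # element-wise gating before the output projection

print(gated.shape)  # (1, 2, 4, 8)
```

Since the gate lies in (0, 1), gating can only attenuate each channel of the attention output, letting the model suppress uninformative positions per head and dimension.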