You deployed your first LLM. Inference is 10x slower than you expected, and your GPU utilization hovers around 30%. You throw money at bigger GPUs, but throughput barely improves. What went wrong?
The bottleneck isn't compute. It's memory bandwidth.
I've seen this pattern repeatedly when teams first deploy transformers at scale. They profile their model expecting to find slow matrix multiplies, but instead discover that attention spends most of its time waiting for data to move between memory hierarchies—specifically, reading from and writing to HBM (High Bandwidth Memory), the GPU's main memory. An A100 can perform 312 trillion FP16 operations per second, but HBM can only transfer 2 terabytes of data per second. When your algorithm requires moving N^2 bytes to compute N^2 operations, you're leaving most of that compute power idle.
This article explains why attention is memory-bound and how two complementary techniques fix it: KV caching for eliminating redundant computation, and FlashAttention (1, 2, and 3) for making the remaining computation IO-efficient. By the end, you'll understand the memory hierarchy that governs GPU performance, the tiling tricks that make FlashAttention work, and when to use each optimization.
Prerequisites: Familiarity with the transformer attention mechanism (softmax(QK^T/sqrt(d))V), basic understanding of GPU architecture (what CUDA cores and memory are), and comfort reading PyTorch code.
A note on scope: This article focuses on the algorithmic ideas behind attention optimization. I've simplified some GPU architecture details where the core insight doesn't depend on them. For the full implementation details, the FlashAttention papers (linked at the end) are excellent.
With that context, let's start with practical guidance—then build up to the theory.
Before diving into the details, here's what to use when:
| Situation | Recommendation |
|---|---|
| Any production inference | Enable KV cache (always) |
| PyTorch 2.0+ | Use scaled_dot_product_attention (includes FlashAttention) |
| Training or long sequences | Use FlashAttention library directly |
| H100 GPUs | Use FlashAttention-3 for best performance |
| Serving many concurrent requests | Add PagedAttention (vLLM) |
| Need to reduce KV cache memory | Consider GQA or MQA architectures |
Key Insight
KV caching and FlashAttention solve different problems. KV caching eliminates redundant computation across generation steps. FlashAttention makes each attention computation more memory-efficient. Use both together for maximum performance.
To understand why your GPU sits idle at 30% utilization, we need to look at what attention actually does to memory.
The standard attention formula looks innocent enough:

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
(The √d scaling is critical: without it, dot products grow with dimension, making softmax nearly one-hot. The scaling keeps attention scores in a reasonable range regardless of head dimension.)
For a sequence of length N with head dimension d, this requires:

- QK^T: roughly 2N^2 d FLOPs to produce the N x N score matrix
- softmax: O(N^2) FLOPs over those scores
- Multiplying by V: roughly another 2N^2 d FLOPs

Total compute: O(N^2 d). That's a lot of operations, but modern GPUs are fast at matrix multiplication. An A100 can do 312 TFLOPs (trillion floating-point operations) of FP16 math per second.
Here's the problem: the standard implementation materializes (allocates and stores) the N x N attention matrix in GPU memory. For a sequence length of 8,192 with 32 attention heads in FP16, that's 8,192 x 8,192 x 32 heads x 2 bytes ≈ 4.3 GB of attention scores per layer.
And you need to read this matrix back to multiply by V. With 2 TB/s memory bandwidth, just moving this data takes 2+ milliseconds per layer. A 32-layer model spends 70+ milliseconds just on memory transfers for attention alone.
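The arithmetic behind these numbers is worth doing once yourself, using the sizes from the text above:

```python
# Memory footprint and HBM transfer time for the materialized attention matrix.
seq_len = 8192
num_heads = 32
bytes_per_elem = 2  # FP16

# N x N scores per head, across all heads, for one layer
matrix_bytes = seq_len * seq_len * num_heads * bytes_per_elem
print(f"Attention matrix per layer: {matrix_bytes / 1e9:.1f} GB")  # ~4.3 GB

# Time to move it through HBM once at 2 TB/s (A100)
hbm_bandwidth = 2e12  # bytes/s
transfer_ms = matrix_bytes / hbm_bandwidth * 1e3
print(f"One HBM pass per layer: {transfer_ms:.2f} ms")  # ~2.1 ms

num_layers = 32
print(f"Across the model: {transfer_ms * num_layers:.0f} ms")  # ~70 ms
```

Note this counts only a single pass over the score matrix; the write after QK^T and the read before multiplying by V each cost another pass.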
The concept that explains this is arithmetic intensity: the ratio of compute operations to bytes moved.
For a GPU to be compute-bound (using all its FLOPs), the arithmetic intensity of your algorithm must exceed the GPU's balance point—the minimum FLOPs per byte needed to saturate compute. Below this threshold, memory bandwidth is the bottleneck; above it, compute can run at full speed.
The balance point is simply: Peak FLOPs ÷ Memory Bandwidth.
| GPU | Memory Bandwidth | Peak FP16 | Balance Point |
|---|---|---|---|
| A100 | 2.0 TB/s | 312 TFLOPs | 312 ÷ 2 = 156 FLOPs/byte |
| H100 | 3.35 TB/s | 989 TFLOPs | 989 ÷ 3.35 ≈ 295 FLOPs/byte |
Standard attention's arithmetic intensity for the softmax step? About 1 FLOP per byte. Let's see where that comes from:
Softmax over N elements:
- FLOPs: ~4N (find max, subtract max, compute exp, divide by sum)
- Memory: Read N values (2N bytes in FP16) + Write N values (2N bytes)
- Arithmetic Intensity: 4N FLOPs / 4N bytes = 1 FLOP/byte
We need 156x more compute per byte to saturate an A100. Softmax is dramatically below this threshold—it's almost entirely memory-bound.
Why Matrix Multiply is Different
Matrix multiplication has high arithmetic intensity because of data reuse: each element is used in multiple operations. Multiplying two N×N matrices requires O(N³) FLOPs but only O(N²) memory accesses—each element participates in N different multiply-accumulate operations. That's O(N) FLOPs per byte, which exceeds the balance point for large N. This is why matrix multiply runs at 70-80% GPU utilization while attention struggles at 30%.
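These ratios are easy to compute. A sketch using the A100/H100 numbers from the table (FLOP and byte counts are the simplified ones from the text):

```python
# Arithmetic intensity (FLOPs per byte moved) vs. a GPU's balance point.
def balance_point(peak_flops, bandwidth_bytes_per_s):
    return peak_flops / bandwidth_bytes_per_s

a100 = balance_point(312e12, 2.0e12)   # ~156 FLOPs/byte
h100 = balance_point(989e12, 3.35e12)  # ~295 FLOPs/byte

def softmax_intensity(n, bytes_per_elem=2):
    flops = 4 * n                          # max, subtract, exp, divide
    bytes_moved = 2 * n * bytes_per_elem   # read N values + write N values
    return flops / bytes_moved

def matmul_intensity(n, bytes_per_elem=2):
    flops = 2 * n**3                       # N^3 multiply-accumulates
    bytes_moved = 3 * n**2 * bytes_per_elem  # read A and B, write C
    return flops / bytes_moved             # grows as O(N)

print(softmax_intensity(8192))  # 1.0 -- far below the balance point
print(matmul_intensity(8192))   # ~2731 -- far above it
```

Softmax sits at 1 FLOP/byte no matter how large N gets, while matmul's intensity grows linearly with N; this is the whole memory-bound vs. compute-bound story in two functions.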
Understanding GPU memory hierarchy explains why materialization is so expensive:
| Level | Size | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~256 KB per SM | - | 1 cycle |
| SRAM (Shared Memory) | 192 KB per SM (A100) | ~19 TB/s | ~30 cycles |
| L2 Cache | 40 MB total | ~5 TB/s | ~200 cycles |
| HBM (Global Memory) | 80 GB | 2 TB/s | ~400 cycles |
SM = Streaming Multiprocessor, the GPU's basic compute unit. An A100 has 108 SMs.
The key insight: SRAM is roughly 10x faster than HBM, but vastly smaller—tens of megabytes across all SMs versus 80 GB of HBM. Standard attention writes the N x N matrix to HBM because it doesn't fit in SRAM. FlashAttention reorganizes the computation so it never needs to.
Before tackling the memory efficiency of attention itself, let's address a different kind of waste: recomputing the same keys and values during autoregressive generation.
When an LLM generates text, it produces one token at a time. At each step, the model attends to all previous tokens. Without caching, this means recomputing K and V for the entire history at every generation step.
Consider generating 100 tokens. Without caching, step t recomputes K and V for all t tokens seen so far, so total K/V computations come to 1 + 2 + 3 + ... + 100 = 5,050. But we only needed 100 unique computations.
Key Insight
The keys and values for position t don't change when generating position t+1. Cache them, and you reduce O(N^2) computation to O(N).
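The triangular sum above is easy to verify:

```python
# Count K/V projection computations when generating tokens one at a time.
def kv_computations(new_tokens, use_cache):
    total = 0
    for step in range(1, new_tokens + 1):
        if use_cache:
            total += 1      # only project the newest token
        else:
            total += step   # recompute K/V for all tokens 0..step
    return total

print(kv_computations(100, use_cache=False))  # 5050
print(kv_computations(100, use_cache=True))   # 100
```

The gap grows quadratically: at 1,000 generated tokens it's 500,500 versus 1,000.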
The idea is simple: store previously computed keys and values, and only compute K/V for new tokens.
Without KV Cache:
For each new token:
Compute Q, K, V for ALL tokens (0 to t)
Attention(Q, K, V)
Return last token's output
With KV Cache:
Initialize empty cache
For the prompt:
Compute K, V for all prompt tokens
Store in cache
For each new token:
Compute Q, K, V for ONLY the new token
Concatenate K_new, V_new to cache
Attention(Q_new, K_cached, V_cached)
Return output
Here's the core modification to a multi-head attention module:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads  # head reshaping omitted to keep the cache logic clear
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # Cache for keys and values
        self.register_buffer("cache_k", None)
        self.register_buffer("cache_v", None)

    def forward(self, x, use_cache=False):
        batch, seq_len, _ = x.shape
        # Compute Q, K, V projections for the new tokens only
        q = self.W_q(x)
        k_new = self.W_k(x)
        v_new = self.W_v(x)
        if use_cache:
            if self.cache_k is None:
                # First call (prompt): initialize cache
                self.cache_k = k_new
                self.cache_v = v_new
            else:
                # Subsequent calls: append new K/V along the sequence dim
                self.cache_k = torch.cat([self.cache_k, k_new], dim=1)
                self.cache_v = torch.cat([self.cache_v, v_new], dim=1)
            # Use the full cache for attention
            k, v = self.cache_k, self.cache_v
        else:
            k, v = k_new, v_new
        # Standard attention computation
        # q: (batch, 1, d_model) during generation
        # k, v: (batch, cached_len, d_model)
        attn_out = F.scaled_dot_product_attention(q, k, v)
        return self.W_o(attn_out)

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None
Don't Forget to Reset
Always reset the KV cache between different prompts/sequences. Forgetting this causes the model to attend to stale context from previous generations, producing incoherent outputs.
Advantages:
Disadvantages:
KV cache memory per token per layer is 2 (one K and one V) x num_heads x head_dim x bytes per value.

For Llama 2 70B (80 layers, 64 heads, head dim 128, FP16, assuming one K/V head per query head): 2 x 64 x 128 x 2 bytes = 32 KB per token per layer.

At 4096 context length: 32 KB x 80 layers x 4096 tokens = 10.5 GB per sequence
This is why techniques like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) matter: they share K/V heads to reduce cache size by 8-32x.
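To see concretely why GQA and MQA matter, here's the cache-size arithmetic as a function of the number of K/V heads, using the formula above (head counts for the GQA/MQA variants are illustrative):

```python
# KV cache per sequence: 2 (K and V) x layers x kv_heads x head_dim
# x context_len x bytes_per_value.
def kv_cache_gb(layers, kv_heads, head_dim, context_len, bytes_per_elem=2):
    return 2 * layers * kv_heads * head_dim * context_len * bytes_per_elem / 1e9

mha = kv_cache_gb(80, 64, 128, 4096)  # one KV head per query head: ~10.5 GB
gqa = kv_cache_gb(80, 8, 128, 4096)   # 8 shared KV heads: 8x smaller
mqa = kv_cache_gb(80, 1, 128, 4096)   # 1 shared KV head: 64x smaller
print(f"MHA {mha:.1f} GB, GQA {gqa:.2f} GB, MQA {mqa:.2f} GB")
```

The reduction factor is exactly num_query_heads / num_kv_heads, which is where the 8-32x range in the text comes from.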
Note: At sequence length 2,048, standard attention materializes a 2,048×2,048 = 4,194,304 element matrix per head per layer. FlashAttention avoids this entirely through tiling.
KV caching handles redundant computation. But each individual attention computation still materializes that N x N matrix. FlashAttention fixes this.
Key Insight
FlashAttention computes exact attention (not an approximation). It just reorganizes the computation to avoid materializing the full attention matrix, reducing HBM accesses dramatically.
Instead of computing the full attention matrix at once, FlashAttention processes it in tiles that fit in fast SRAM. The algorithm:

- Load a block of Q and a block of K/V from HBM into SRAM
- Compute the attention scores for that tile and fold them into a running softmax
- Accumulate the partial output in SRAM, writing only the final result back to HBM
The challenge is softmax: it requires knowing the maximum and sum over the entire row to normalize. How do you compute softmax when you only see parts of the row at a time?
For intuition at toy scale: rather than materializing a full 8x8 = 64-element attention matrix in HBM, FlashAttention would process it as sixteen 2x2 tiles of only 4 elements each, with each tile fitting comfortably in fast SRAM.
Here's the challenge: softmax needs global information. To compute softmax(x_i) = exp(x_i - m) / sum_j exp(x_j - m), you need the maximum m over the entire row (for numerical stability) and the sum of exponentials over the entire row (for normalization).
But we're processing tiles. When we see tile 1, we don't know what values are in tile 2. How can we compute a correct softmax without seeing the whole row?
The key insight is that softmax(x) = exp(x - m) / sum(exp(x - m)) remains valid for any value of m—we just need to use the same m everywhere.
When we discover a larger maximum m_new, we can "correct" our previous sum by multiplying by exp(m_old - m_new). The math relies on a simple identity: exp(a + b) = exp(a) × exp(b). Therefore:
exp(x - m_old) = exp(x - m_new + m_new - m_old) = exp(x - m_new) × exp(m_new - m_old)
This shows how values computed with the old max can be converted to the new max with a single multiplication.
This correction is the key to processing tiles separately: if tile 2 has a larger max than tile 1, we retroactively adjust tile 1's results with a single multiply per element. That's the cost of processing in chunks instead of all at once—and it's exact, not an approximation.
Standard softmax requires two passes:
# Pass 1: find max for numerical stability
m = max(scores)
# Pass 2: compute exp and sum
exp_scores = exp(scores - m)
output = exp_scores / sum(exp_scores)
Online softmax computes this incrementally. When processing a new block of scores, we can update running statistics:
# Current state: (m_prev, l_prev, acc_prev)
# New block: scores_new
# Step 1: Update maximum
m_new = max(m_prev, max(scores_new))
# Step 2: Compute correction factor for previous accumulator
correction = exp(m_prev - m_new)
# Step 3: Update running sum with corrected old + new
l_new = correction * l_prev + sum(exp(scores_new - m_new))
# Step 4: Update running output (rescale old output + add new)
acc_new = correction * acc_prev + exp(scores_new - m_new) @ v_new
Key insight: Standard softmax needs two passes (find max, then compute). Online softmax does it in one pass by tracking a running max and correcting previous values when a new max is found.
Numerical verification: Let's process scores [2, 5, 3] in two blocks: [2, 5] then [3].
Block 1 ([2, 5]):
m = 5
l = exp(2-5) + exp(5-5) = 0.05 + 1.0 = 1.05
Block 2 ([3]):
m_new = max(5, 3) = 5 (unchanged)
correction = exp(5 - 5) = 1.0
l_new = 1.0 × 1.05 + exp(3-5) = 1.05 + 0.14 = 1.19
Standard softmax: sum(exp([2,5,3] - 5)) = 0.05 + 1.0 + 0.14 = 1.19 ✓
The online algorithm produces the exact same result! This allows computing exact softmax while only keeping O(block_size) values in SRAM instead of O(N).
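The worked example above can be wrapped in a few lines of pure Python—a sketch that processes scores block by block, tracking only the running max and running sum:

```python
import math

def online_softmax(xs, block_size):
    """One-pass softmax over xs, processed in blocks (educational sketch)."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        m_new = max(m, max(block))
        # Correct the old sum for the new max; exp(-inf) = 0 on the first block
        correction = math.exp(m - m_new)
        l = correction * l + sum(math.exp(x - m_new) for x in block)
        m = m_new
    # Final (m, l) recover the exact softmax
    return [math.exp(x - m) / l for x in xs]

probs = online_softmax([2.0, 5.0, 3.0], block_size=2)
print(probs)  # identical to standard softmax of [2, 5, 3]
```

Only the scalar pair (m, l) survives between blocks—the SRAM-resident state is O(block_size), never O(N).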
This is where FlashAttention's theory is elegant. Let N be the sequence length, d the head dimension, and M the size of SRAM.
Standard attention HBM accesses: O(Nd + N^2)
FlashAttention HBM accesses: O(N^2 d^2 / M)
Where does this come from? With M bytes of SRAM and d-dimensional vectors, we can fit roughly M/d vectors per tile. Here's the step-by-step:
tile_size ~ M/d vectors
number of tile pairs ~ (N / tile_size)² = (Nd/M)²
each tile pair transfers ~ tile_size × d = M bytes
total HBM accesses ~ (Nd/M)² × M = N²d²/M
For typical values (N=2048, d=64, M=96KB of SRAM, ~49K FP16 elements): standard attention makes about N^2 ≈ 4.2 million accesses for the score matrix alone, while N^2 d^2 / M works out to roughly 350 thousand—about a 12x reduction in HBM traffic.
Theoretical Optimality
Dao et al. proved that FlashAttention is asymptotically optimal: no algorithm can have fewer than Omega(N^2 d^2 / M) HBM accesses for computing exact attention. The only way to do better is to approximate attention.
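Plugging numbers into these asymptotic expressions makes the gap concrete (a back-of-envelope sketch; M is counted in FP16 elements, and constant factors are dropped):

```python
# HBM access counts (in elements): standard attention vs. FlashAttention.
def standard_hbm(n, d):
    return n * d + n * n              # Q/K/V/O traffic plus the N x N matrix

def flash_hbm(n, d, sram_elems):
    return n * n * d * d / sram_elems  # the N^2 d^2 / M bound

n, d = 2048, 64
sram = 96 * 1024 // 2                 # 96 KB of SRAM as FP16 elements
ratio = standard_hbm(n, d) / flash_hbm(n, d, sram)
print(f"FlashAttention reduction: {ratio:.1f}x fewer HBM accesses")
```

Note how the advantage scales: doubling N quadruples both counts equally, but increasing d actually hurts FlashAttention's bound (d^2 in the numerator), which is why small head dimensions tile so well.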
Here's the complete algorithm. The key points to notice are: (1) the outer loop over query blocks, (2) the inner loop over key/value blocks, and (3) the online softmax update that happens inside the inner loop.
def flash_attention(Q, K, V, block_size_q, block_size_kv):
"""
Q: (N, d) queries
K: (N, d) keys
V: (N, d) values
Returns: (N, d) attention output
"""
N, d = Q.shape
    O = zeros((N, d))  # Output accumulator (running max and normalizer live per query block below)
# Outer loop: iterate over query blocks
for i in range(0, N, block_size_q):
q_block = Q[i:i+block_size_q] # Load Q block to SRAM
o_block = zeros((block_size_q, d))
l_block = zeros(block_size_q)
m_block = full(block_size_q, -inf)
# Inner loop: iterate over key/value blocks
for j in range(0, N, block_size_kv):
k_block = K[j:j+block_size_kv] # Load K block to SRAM
v_block = V[j:j+block_size_kv] # Load V block to SRAM
# Compute attention scores for this tile
scores = q_block @ k_block.T / sqrt(d)
# Online softmax update
m_new = maximum(m_block, scores.max(dim=1))
# Rescale previous accumulator
scale = exp(m_block - m_new)
l_block = scale * l_block + exp(scores - m_new[:, None]).sum(dim=1)
# Update output accumulator
o_block = scale[:, None] * o_block + exp(scores - m_new[:, None]) @ v_block
m_block = m_new
# Final normalization and write to HBM
O[i:i+block_size_q] = o_block / l_block[:, None]
return O
The key logic is in the inner loop (the for j in range block): each iteration updates the running maximum, rescales the previous accumulator, and adds the new contribution. When the inner loop completes, o_block contains the exact same result as standard attention would produce—but we never allocated an N × N matrix.
Choosing Block Sizes
Block size determines a trade-off: larger blocks (128-256) mean fewer SRAM trips but higher memory usage per block. Smaller blocks (32-64) allow more concurrent warps but increase transfer overhead. Typical values on A100: block_q=128, block_kv=64, fitting within the ~96KB SRAM budget. The actual FlashAttention implementation adapts block sizes based on available SRAM.
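A rough way to sanity-check a block-size choice against the SRAM budget—the buffer accounting here is simplified (real kernels also hold softmax statistics and double-buffered tiles):

```python
# Approximate SRAM needed for one tile pair at FP16 (2 bytes per element).
def tile_sram_kb(block_q, block_kv, d=64, bytes_per_elem=2):
    elems = (block_q * d            # Q tile
             + block_kv * d * 2     # K and V tiles
             + block_q * block_kv   # score tile
             + block_q * d)         # output accumulator
    return elems * bytes_per_elem / 1024

print(f"{tile_sram_kb(128, 64):.0f} KB")   # fits the ~96 KB budget
print(f"{tile_sram_kb(256, 256):.0f} KB")  # too large for one block
```

Doubling both block dimensions roughly quadruples the score tile, which is why block sizes plateau around 128-256 even on GPUs with generous shared memory.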
From the FlashAttention paper (Dao et al., 2022):
| Benchmark | Standard Attention | FlashAttention | Speedup |
|---|---|---|---|
| Attention alone (seq 2K) | 1x | 7.6x | 7.6x |
| GPT-2 training (seq 1K) | 1x | 3x | 3x |
| BERT-large (seq 512) | 1x | 1.15x | 1.15x |
| Long Range Arena (seq 4K) | 1x | 2.4x | 2.4x |
Note: BERT-large's smaller speedup (1.15x) reflects sequence length 512, where the 512² = 262K-element score matrix (~512 KB in FP16) largely fits in the A100's 40 MB L2 cache. At this scale, tiling overhead roughly offsets the reduced HBM traffic.
Memory savings at sequence length 4K: 20x reduction (no N^2 materialization).
What About the Backward Pass?
FlashAttention's backward pass uses similar tiling techniques to achieve O(N²d²/M) HBM accesses. The implementation is more complex (requires careful handling of recomputation vs. storage for gradients), but the asymptotic complexity is similar to the forward pass.
When FlashAttention Doesn't Help
FlashAttention has overhead from its tiling logic. For very short sequences (N < 256) where the N×N matrix already fits comfortably in SRAM, standard attention may be faster. Also, some implementations have head dimension constraints (e.g., d ≤ 128 or must be a multiple of 8)—check your library's documentation. FlashAttention shines at longer sequences where the quadratic memory would otherwise be prohibitive.
FlashAttention 1 reduced HBM accesses dramatically, but GPU profiling revealed a surprise: utilization was still only 25-40% of theoretical peak. FlashAttention 2 (Dao, 2023) pushed this to 50-73% through better work partitioning.
Quick GPU Terminology
Before diving in, a brief refresher on GPU concepts: a warp is a group of 32 threads that execute in lockstep; warps are scheduled onto SMs (streaming multiprocessors); shared memory is the per-SM SRAM that warps use to exchange data; and occupancy measures how many warps an SM keeps in flight at once.
Here's a key insight: on modern GPUs, non-matmul FLOPs are 16x more expensive than matmul FLOPs in terms of throughput.
| Operation | A100 Throughput |
|---|---|
| FP16 Matrix Multiply | 312 TFLOPs |
| FP32 Scalar Exp | 19.5 TFLOPs |
This 16x gap exists because GPUs have specialized Tensor Cores for matrix multiply, while operations like exp() and division run on general-purpose ALUs. FlashAttention 1 spent too many cycles on softmax-related operations. FlashAttention 2 reduces these by deferring the output rescaling—dividing by the softmax normalizer once at the end of the inner loop rather than rescaling on every step—and by storing a single logsumexp value per row for the backward pass instead of both the max and the sum.
FlashAttention 1 used "split-K" parallelism: different warps processed different parts of K/V and synchronized via shared memory to combine results.
FlashAttention 2 flips this: split Q across warps while each warp processes the full K/V. This eliminates synchronization between warps and reduces shared memory reads/writes.
FlashAttention 1 (split-K):
Warp 0: Q_full x K_part1 → partial_output_1 ─┐
Warp 1: Q_full x K_part2 → partial_output_2 ─┼→ sync → combine → output
Warp 2: Q_full x K_part3 → partial_output_3 ─┤
Warp 3: Q_full x K_part4 → partial_output_4 ─┘
FlashAttention 2 (split-Q):
Warp 0: Q_part1 x K_full → output_1 (no sync needed!)
Warp 1: Q_part2 x K_full → output_2
Warp 2: Q_part3 x K_full → output_3
Warp 3: Q_part4 x K_full → output_4
Sequence-level parallelism: FA1 only parallelized across batch and heads. FA2 also parallelizes across the sequence dimension, improving utilization for small batch sizes.
Reduced shared memory traffic: By keeping more state in registers, FA2 cuts shared memory reads/writes significantly.
Better occupancy: Careful tuning of block sizes to balance register usage and SM occupancy.
| Configuration | FA1 | FA2 | Speedup |
|---|---|---|---|
| Forward pass, seq 2K | 124 TFLOPs | 230 TFLOPs | 1.9x |
| End-to-end training | - | 225 TFLOPs | - |
| GPU utilization | 25-40% | 50-73% | ~2x |
FlashAttention 2 achieves approximately 9x speedup over PyTorch standard attention on long sequences (from Dao, 2023) and 2x speedup over FlashAttention 1.
FlashAttention 2's optimizations were designed for the A100 architecture. When NVIDIA released the H100 with fundamentally new hardware features, new optimization opportunities emerged that FlashAttention 2 couldn't exploit. FlashAttention 3 (Shah et al., 2024) targets these specifically, achieving 75% utilization and 740 TFLOPs on H100.
Note
This section covers H100-specific optimizations. You don't need to understand the details to use FlashAttention 3—on H100, install a flash-attn build that includes the FA3 kernels (or use a serving stack that bundles them) and you get the speedup without changing your model code. But if you're curious about why it's roughly 2x faster, read on.
The H100 introduced hardware specifically designed for transformer workloads: TMA (Tensor Memory Accelerator) for asynchronous bulk copies between HBM and shared memory, WGMMA (warpgroup-level matrix multiply) instructions that let Tensor Cores run asynchronously, and FP8 Tensor Core support at twice the FP16 throughput.
Instead of having all warps do the same work, FA3 assigns different roles: producer warps issue TMA loads to fetch the next K/V tiles from HBM, while consumer warps run the matrix multiplies and softmax on tiles already in shared memory. This overlaps data movement with computation, hiding memory latency.
The expensive non-matmul operations (softmax exp, rescaling) happen between matrix multiplies. FA3 interleaves two warpgroups:
Time →
Warpgroup 0: [GEMM K0] [softmax] [GEMM V0] [softmax] [GEMM K2] ...
Warpgroup 1: [GEMM K1] [softmax] [GEMM V1] [softmax] ...
While one warpgroup does softmax, the other does GEMM. This hides the throughput gap between these operations—on H100, this gap reaches ~250x (989 TFLOPs for matmul vs ~4 TFLOPs for scalar ops), making overlap even more critical than on A100's 16x gap.
H100 supports FP8 (8-bit floating point) with 2x the throughput of FP16. But naive FP8 attention has high quantization error due to outliers in attention scores.
FA3 uses incoherent processing: apply Hadamard transforms (structured orthogonal matrices that redistribute values across dimensions) to Q and K before quantization. This spreads outliers across dimensions rather than concentrating them, reducing quantization error by 2.6x while maintaining speed benefits.
| Precision | FlashAttention 2 | FlashAttention 3 | Speedup |
|---|---|---|---|
| FP16 | ~370 TFLOPs | 740 TFLOPs | 2.0x |
| FP8 | N/A | ~1,200 TFLOPs | - |
FA3 reaches 75% of H100's theoretical FP16 peak, up from 35% with FA2.
FA3 is specifically optimized for H100 GPUs and requires CUDA 12+. If you're on A100 or older hardware, stick with FA2. The good news: PyTorch's scaled_dot_product_attention automatically selects the best available backend.
Since PyTorch 2.0, FlashAttention is built-in:
import torch
import torch.nn.functional as F
# Automatic backend selection (FlashAttention when available)
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None, # Optional attention mask
dropout_p=0.0, # Dropout probability
is_causal=True, # Use causal masking (for autoregressive)
)
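Whichever backend SDPA selects, the result matches the textbook formula—it is exact attention, not an approximation. A quick CPU-safe sanity check (shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

# SDPA expects (batch, heads, seq_len, head_dim)
torch.manual_seed(0)
q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Reference: explicit softmax(QK^T / sqrt(d)) V
scale = q.shape[-1] ** -0.5
weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
manual = weights @ v

print(torch.allclose(out, manual, atol=1e-5))  # True
```

The same check works on GPU, where SDPA will route to the FlashAttention kernel when the dtype and head dimension qualify.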
PyTorch automatically selects the best backend: the FlashAttention kernel when the inputs qualify (GPU, supported dtype and head dimension), a memory-efficient attention kernel otherwise, and a math fallback that works everywhere.
For more control or features not in PyTorch:
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# query, key, value: (batch, seqlen, nheads, headdim)
output = flash_attn_func(
query, key, value,
dropout_p=0.0,
softmax_scale=None, # Default: 1/sqrt(headdim)
causal=True,
)
Most model libraries support FlashAttention automatically:
# HuggingFace Transformers
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2", # Enable FlashAttention
torch_dtype=torch.float16,
)
These optimizations are complementary: KV caching eliminates redundant computation across generation steps, while FlashAttention reduces memory traffic within each attention call.
For maximum performance, use both:
# Example: HuggingFace with both optimizations
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
)
outputs = model.generate(
input_ids,
max_new_tokens=100,
use_cache=True, # Enable KV caching
)
Understanding why attention is slow is the first step to fixing it. The key insights:
Attention is memory-bound, not compute-bound. The N^2 attention matrix causes memory bandwidth to become the bottleneck, not compute throughput.
KV caching eliminates redundant computation. By caching keys and values, each generation step projects only the new token instead of the whole history—O(1) instead of O(N) K/V projection work per step—typically giving 5-10x speedups for generation.
FlashAttention is exact attention, not an approximation. It uses tiling and online softmax to avoid materializing the N^2 matrix, reducing memory access by 10-20x.
Hardware-aware optimization matters. FlashAttention 2 and 3 exploit specific GPU features (warp partitioning, TMA, WGMMA) to achieve 50-75% utilization where naive implementations achieve 25-30%.
Use both KV caching and FlashAttention together. They solve different problems and stack multiplicatively.
For most users, the practical advice is simple: use PyTorch 2.0+ with scaled_dot_product_attention, enable KV caching for generation, and the framework handles the rest.
- torch.profiler to identify whether attention is your bottleneck
- TORCH_LOGS="+dynamo" to see backend selection
- torch.cuda.memory_summary() to see where memory goes

Understanding these techniques transforms you from someone who throws hardware at performance problems to someone who can diagnose and fix them at the algorithmic level.
Remember that 30% GPU utilization we started with? Now you know why: the attention matrix was forcing O(N²) memory traffic, starving the compute units. Enable FlashAttention and that utilization jumps to 75%. Add KV caching and you eliminate redundant computation entirely. The same hardware, the same model—just algorithms that respect the memory hierarchy.
For educational purposes, here's a minimal tiled attention implementation that demonstrates the core FlashAttention concept without CUDA:
import torch
import math
def tiled_attention(Q, K, V, block_size=256):
"""
Simplified tiled attention for educational purposes.
Not optimized for production - use flash_attn or PyTorch's SDPA instead.
Q, K, V: (batch, seq_len, d_model)
Returns: (batch, seq_len, d_model)
"""
batch, seq_len, d_model = Q.shape
scale = 1.0 / math.sqrt(d_model)
    # Output accumulator (the running max and normalizer are kept per query block)
    O = torch.zeros_like(Q)
# Iterate over query blocks
for i in range(0, seq_len, block_size):
q_end = min(i + block_size, seq_len)
q_block = Q[:, i:q_end, :] # (batch, block_q, d)
# Local accumulators for this query block
o_block = torch.zeros_like(q_block)
l_block = torch.zeros(batch, q_end - i, 1, device=Q.device)
m_block = torch.full((batch, q_end - i, 1), float('-inf'), device=Q.device)
# Iterate over key/value blocks
for j in range(0, seq_len, block_size):
k_end = min(j + block_size, seq_len)
k_block = K[:, j:k_end, :] # (batch, block_kv, d)
v_block = V[:, j:k_end, :]
# Compute attention scores: (batch, block_q, block_kv)
scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale
# Online softmax update
m_new = torch.maximum(m_block, scores.max(dim=-1, keepdim=True).values)
# Rescale previous accumulator
exp_diff = torch.exp(m_block - m_new)
l_block = exp_diff * l_block + torch.exp(scores - m_new).sum(dim=-1, keepdim=True)
# Update output: rescale old + add new contribution
o_block = exp_diff * o_block + torch.bmm(torch.exp(scores - m_new), v_block)
m_block = m_new
# Final normalization and store
O[:, i:q_end, :] = o_block / l_block
return O
# Test correctness
def test_tiled_attention():
torch.manual_seed(42)
batch, seq_len, d_model = 2, 1024, 64
Q = torch.randn(batch, seq_len, d_model)
K = torch.randn(batch, seq_len, d_model)
V = torch.randn(batch, seq_len, d_model)
# Reference: standard attention
scale = 1.0 / math.sqrt(d_model)
attn_weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) * scale, dim=-1)
reference = torch.bmm(attn_weights, V)
# Our tiled implementation
tiled = tiled_attention(Q, K, V, block_size=128)
# Check correctness
max_diff = (reference - tiled).abs().max().item()
print(f"Max difference from reference: {max_diff:.2e}")
assert max_diff < 1e-5, "Tiled attention differs from reference!"
print("Test passed!")
if __name__ == "__main__":
test_tiled_attention()
This implementation shows the key ideas (tiling, online softmax) but runs in Python and won't be faster than standard attention. The real speedup comes from the CUDA implementation that keeps tiles in SRAM and fuses all operations into a single kernel.
Papers:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023)
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Shah et al., 2024)

Code:
- flash-attn, the official FlashAttention implementation: github.com/Dao-AILab/flash-attention
- PyTorch's scaled_dot_product_attention (torch.nn.functional)