Attention Optimization: From Memory Walls to Flash Attention
You deployed your first LLM. Inference is 10x slower than you expected, and your GPU utilization hovers around 30%. You throw money at bigger GPUs, but throughput barely improves. What went wrong?
The bottleneck isn't compute. It's memory bandwidth.
I've seen this pattern repeatedly when teams first deploy transformers at scale. They profile their model expecting to find slow matrix multiplies, but instead discover that attention spends most of its time waiting for data to move between memory hierarchies—specifically, reading from and writing to HBM (High Bandwidth Memory), the GPU's main memory. An A100 can perform 312 trillion FP16 operations per second, but HBM can only transfer 2 terabytes of data per second. When your algorithm requires moving N^2 bytes to compute N^2 operations, you're leaving most of that compute power idle.
This article explains why attention is memory-bound and how two complementary techniques fix it: KV caching, which eliminates redundant computation, and FlashAttention (1, 2, and 3), which makes the remaining computation IO-efficient. By the end, you'll understand the memory hierarchy that governs GPU performance, the tiling tricks that make FlashAttention work, and when to use each optimization.
What you'll learn:
- Why attention is memory-bound, not compute-bound
- How KV caching eliminates redundant key/value recomputation, cutting total K/V work during generation from O(N^2) to O(N)
- How FlashAttention reduces memory access from O(N^2) to O(N^2 d^2 / M), where M is SRAM size
- The key innovations in FlashAttention 2 and 3
- Practical decision framework for which optimizations to use
Prerequisites: Familiarity with the transformer attention mechanism (softmax(QK^T/sqrt(d))V), basic understanding of GPU architecture (what CUDA cores and memory are), and comfort reading PyTorch code.
A note on scope: This article focuses on the algorithmic ideas behind attention optimization. I've simplified some GPU architecture details where the core insight doesn't depend on them. For the full implementation details, the FlashAttention papers (linked at the end) are excellent.
With that context, let's start with practical guidance—then build up to the theory.
Quick Decision Framework
Before diving into the details, here's what to use when:
| Situation | Recommendation |
|---|---|
| Any production inference | Enable KV cache (always) |
| PyTorch 2.0+ | Use scaled_dot_product_attention (includes FlashAttention) |
| Training or long sequences | Use FlashAttention library directly |
| H100 GPUs | Use FlashAttention-3 for best performance |
| Serving many concurrent requests | Add PagedAttention (vLLM) |
| Need to reduce KV cache memory | Consider GQA or MQA architectures |
Key Insight
KV caching and FlashAttention solve different problems. KV caching eliminates redundant computation across generation steps. FlashAttention makes each attention computation more memory-efficient. Use both together for maximum performance.
The Memory Wall
To understand why your GPU sits idle at 30% utilization, we need to look at what attention actually does to memory.
The standard attention formula looks innocent enough:
Attention(Q, K, V) = softmax(QK^T / √d) V
(The √d scaling is critical: without it, dot products grow with dimension, making softmax nearly one-hot. The scaling keeps attention scores in a reasonable range regardless of head dimension.)
For a sequence of length N with head dimension d, this requires:
- Compute QK^T: O(N^2 d) FLOPs, producing an N x N matrix
- Apply softmax row-wise: O(N^2) operations
- Multiply by V: O(N^2 d) FLOPs
Total compute: O(N^2 d). That's a lot of operations, but modern GPUs are fast at matrix multiplication. An A100 can do 312 TFLOPs (trillion floating-point operations) of FP16 math per second.
Here's the problem: the standard implementation materializes (allocates and stores) the N x N attention matrix in GPU memory. For a sequence length of 8,192 with 32 attention heads in FP16, that's 8,192 × 8,192 × 32 × 2 bytes ≈ 4.3 GB per layer.
And you need to read this matrix back to multiply by V. With 2 TB/s memory bandwidth, just moving this data takes 2+ milliseconds per layer. A 32-layer model spends 70+ milliseconds just on memory transfers for attention alone.
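If you want to sanity-check these numbers for your own model, the arithmetic is a few lines (a rough sketch; the constants are the A100 figures used above):

seq_len, num_heads, bytes_fp16 = 8192, 32, 2
hbm_bandwidth_bytes_per_s = 2e12                      # A100: ~2 TB/s

scores_bytes = seq_len * seq_len * num_heads * bytes_fp16
one_pass_s = scores_bytes / hbm_bandwidth_bytes_per_s

print(f"score matrices per layer: {scores_bytes / 1e9:.1f} GB")          # ~4.3 GB
print(f"one pass over them:       {one_pass_s * 1e3:.1f} ms per layer")  # ~2.1 ms (write + read doubles it)
print(f"32 layers:                {32 * one_pass_s * 1e3:.0f} ms")       # ~69 ms at minimum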
Arithmetic Intensity: The Real Bottleneck
The concept that explains this is arithmetic intensity: the ratio of compute operations to bytes moved.
For a GPU to be compute-bound (using all its FLOPs), the arithmetic intensity of your algorithm must exceed the GPU's balance point—the minimum FLOPs per byte needed to saturate compute. Below this threshold, memory bandwidth is the bottleneck; above it, compute can run at full speed.
The balance point is simply: Peak FLOPs ÷ Memory Bandwidth.
| GPU | Memory Bandwidth | Peak FP16 | Balance Point |
|---|---|---|---|
| A100 | 2.0 TB/s | 312 TFLOPs | 312 ÷ 2 = 156 FLOPs/byte |
| H100 | 3.35 TB/s | 989 TFLOPs | 989 ÷ 3.35 ≈ 295 FLOPs/byte |
Standard attention's arithmetic intensity for the softmax step? About 1 FLOP per byte. Let's see where that comes from:
Softmax over N elements:
- FLOPs: ~4N (find max, subtract max, compute exp, divide by sum)
- Memory: Read N values (2N bytes in FP16) + Write N values (2N bytes)
- Arithmetic Intensity: 4N FLOPs / 4N bytes = 1 FLOP/byte
We need 156x more compute per byte to saturate an A100. Softmax is dramatically below this threshold—it's almost entirely memory-bound.
Why Matrix Multiply is Different
Matrix multiplication has high arithmetic intensity because of data reuse: each element is used in multiple operations. Multiplying two N×N matrices requires O(N³) FLOPs but only O(N²) memory accesses—each element participates in N different multiply-accumulate operations. That's O(N) FLOPs per byte, which exceeds the balance point for large N. This is why matrix multiply runs at 70-80% GPU utilization while attention struggles at 30%.
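To make the comparison concrete, here's a back-of-the-envelope calculation (a sketch under the FP16 assumptions above; it ignores caches and counts only HBM traffic):

# A100 balance point: FLOPs it can execute per byte it can move
peak_flops = 312e12
hbm_bandwidth = 2.0e12
balance = peak_flops / hbm_bandwidth                     # ~156 FLOPs/byte

N = 8192

# Softmax over N elements: ~4N FLOPs, ~4N bytes (read N + write N values in FP16)
softmax_intensity = (4 * N) / (4 * N)                    # 1 FLOP/byte -> memory-bound

# N x N FP16 matmul: 2N^3 FLOPs, ~3 N^2-element matrices moved at 2 bytes each
matmul_intensity = (2 * N**3) / (3 * N**2 * 2)           # ~N/3 FLOPs/byte -> compute-bound

print(f"balance point: {balance:.0f} FLOPs/byte")
print(f"softmax:       {softmax_intensity:.0f} FLOP/byte")
print(f"matmul (N={N}): {matmul_intensity:.0f} FLOPs/byte")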
The Memory Hierarchy
Understanding GPU memory hierarchy explains why materialization is so expensive:
| Level | Size | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~256 KB per SM | - | 1 cycle |
| SRAM (Shared Memory) | 192 KB per SM (A100) | ~19 TB/s | ~30 cycles |
| L2 Cache | 40 MB total | ~5 TB/s | ~200 cycles |
| HBM (Global Memory) | 80 GB | 2 TB/s | ~400 cycles |
SM = Streaming Multiprocessor, the GPU's basic compute unit. An A100 has 108 SMs.
The key insight: SRAM is roughly 10x faster than HBM, but it is tiny: about 20 MB in total across all SMs (192 KB × 108) versus 80 GB of HBM. Standard attention writes the N x N matrix to HBM because it doesn't fit in SRAM. FlashAttention reorganizes the computation so it never needs to.
KV Cache: Eliminating Redundant Computation
Before tackling the memory efficiency of attention itself, let's address a different kind of waste: recomputing the same keys and values during autoregressive generation.
When an LLM generates text, it produces one token at a time. At each step, the model attends to all previous tokens. Without caching, this means recomputing K and V for the entire history at every generation step.
Consider generating 100 tokens:
- Step 1: Compute K, V for token 1
- Step 2: Recompute K, V for tokens 1-2 (token 1 was already computed!)
- Step 3: Recompute K, V for tokens 1-3 (tokens 1-2 were already computed!)
- ...
- Step 100: Recompute K, V for tokens 1-100
Total K/V computations: 1 + 2 + 3 + ... + 100 = 5,050. But we only needed 100 unique computations.
Key Insight
The keys and values for position t don't change when generating position t+1. Cache them, and you reduce O(N^2) computation to O(N).
How KV Caching Works
The idea is simple: store previously computed keys and values, and only compute K/V for new tokens.
Without KV Cache:
For each new token:
Compute Q, K, V for ALL tokens (0 to t)
Attention(Q, K, V)
Return last token's output
With KV Cache:
Initialize empty cache
For the prompt:
Compute K, V for all prompt tokens
Store in cache
For each new token:
Compute Q, K, V for ONLY the new token
Concatenate K_new, V_new to cache
Attention(Q_new, K_cached, V_cached)
Return output
Implementation
Here's the core modification to a multi-head attention module:
1. Register Cache Buffers
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads  # head splitting is omitted below to keep the focus on caching
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # Cache for keys and values (populated lazily on the first forward pass)
        self.register_buffer("cache_k", None)
        self.register_buffer("cache_v", None)
2. Forward Pass with Caching
def forward(self, x, use_cache=False):
batch, seq_len, _ = x.shape
# Compute Q, K, V projections for new tokens
q = self.W_q(x)
k_new = self.W_k(x)
v_new = self.W_v(x)
if use_cache:
if self.cache_k is None:
# First call: initialize cache
self.cache_k = k_new
self.cache_v = v_new
else:
# Subsequent calls: append to cache
self.cache_k = torch.cat([self.cache_k, k_new], dim=1)
self.cache_v = torch.cat([self.cache_v, v_new], dim=1)
# Use full cache for attention
k, v = self.cache_k, self.cache_v
else:
k, v = k_new, v_new
        # Standard attention computation (a full implementation would first split
        # q, k, v into num_heads; omitted here for brevity)
        # q: (batch, 1, d_model) during incremental generation
        # k, v: (batch, cached_len, d_model)
        attn_out = F.scaled_dot_product_attention(q, k, v)
        return self.W_o(attn_out)
3. Reset Between Sequences
def reset_cache(self):
self.cache_k = None
self.cache_v = None
Don't Forget to Reset
Always reset the KV cache between different prompts/sequences. Forgetting this causes the model to attend to stale context from previous generations, producing incoherent outputs.
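Here's how the module above might be driven during generation (a minimal sketch: the d_model, dummy inputs, and loop length are illustrative, and a real decoder would also apply positional encodings and causal masking on the prompt pass):

attn = MultiHeadAttention(d_model=512, num_heads=8)

with torch.no_grad():
    # Prompt pass: K/V for every prompt token land in the cache
    prompt = torch.randn(1, 16, 512)              # (batch, prompt_len, d_model), stand-in embeddings
    _ = attn(prompt, use_cache=True)

    # Decode loop: feed only the newest token; cached K/V cover the history
    for step in range(8):
        new_token = torch.randn(1, 1, 512)        # stand-in for the embedding of the sampled token
        out = attn(new_token, use_cache=True)     # attends over 16 + step + 1 cached positions

attn.reset_cache()                                # required before starting an unrelated prompt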
KV Cache Trade-offs
Advantages:
- (+) Eliminates O(N) redundant key/value computations per step
- (+) Reduces total K/V projection work across a generation from O(N^2) to O(N)
- (+) 5-10x speedup for typical generation lengths
- (+) Simple to implement
Disadvantages:
- (-) Memory grows linearly with sequence length
- (-) For batch size 512, context 2048, Llama-70B: KV cache requires ~2.7 TB (32 KB × 80 layers × 2048 × 512 = 2.7 TB—that's ~19× the model's 140 GB weight memory)
- (-) Adds code complexity for position tracking
- (-) Must manage cache lifecycle carefully
Memory Estimates
KV cache memory per token per layer is 2 (K and V) × num_heads × head_dim × bytes per value.
For Llama 2 70B (80 layers, 64 heads, head dim 128, FP16): 2 × 64 × 128 × 2 bytes = 32 KB per token per layer.
At 4096 context length: 32 KB x 80 layers x 4096 tokens ≈ 10.7 GB per sequence.
This is why techniques like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) matter: they share K/V heads to reduce cache size by 8-32x.
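The same arithmetic as a small helper you can adapt (a sketch; the numbers below mirror the multi-head assumption used above, and for a GQA model you would pass its number of K/V heads instead):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Total K/V cache size: 2 tensors (K and V), per layer, per token, per K/V head."""
    return 2 * num_kv_heads * head_dim * bytes_per_value * num_layers * seq_len * batch_size

# Llama 2 70B treated as full multi-head attention (64 K/V heads), FP16, 4096-token context
print(kv_cache_bytes(80, 64, 128, 4096) / 1e9)    # ~10.7 GB per sequence
# The same model with GQA and 8 K/V heads: 8x smaller
print(kv_cache_bytes(80, 8, 128, 4096) / 1e9)     # ~1.3 GB per sequence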
Note: At sequence length 2,048, standard attention materializes a 2,048×2,048 = 4,194,304 element matrix per head per layer. FlashAttention avoids this entirely through tiling.
FlashAttention 1: IO-Aware Exact Attention
KV caching handles redundant computation. But each individual attention computation still materializes that N x N matrix. FlashAttention fixes this.
Key Insight
FlashAttention computes exact attention (not an approximation). It just reorganizes the computation to avoid materializing the full attention matrix, reducing HBM accesses dramatically.
The Core Idea: Tiling
Instead of computing the full attention matrix at once, FlashAttention processes it in tiles that fit in fast SRAM. The algorithm:
- Divide Q, K, V into blocks that fit in SRAM
- For each block of Q:
  - Load the Q block to SRAM
  - For each block of K, V:
    - Load the K, V blocks to SRAM
    - Compute partial attention scores
    - Update the running output using online softmax
  - Write the output block to HBM
- Never materialize the full N x N matrix
The challenge is softmax: it requires knowing the maximum and sum over the entire row to normalize. How do you compute softmax when you only see parts of the row at a time?
FlashAttention tiling, illustrated on a toy example: instead of materializing a full 8×8 = 64-element attention matrix in HBM, the computation proceeds in 2×2 tiles that fit in fast SRAM: 16 tiles of only 4 elements each.
Online Softmax: The Clever Trick
Here's the challenge: softmax needs global information. To compute softmax(x_i), you need:
- The maximum value across ALL elements (for numerical stability)
- The sum of exp(x_j - max) across ALL elements (for normalization)
But we're processing tiles. When we see tile 1, we don't know what values are in tile 2. How can we compute a correct softmax without seeing the whole row?
The key insight is that softmax(x) = exp(x - m) / sum(exp(x - m)) remains valid for any value of m—we just need to use the same m everywhere.
When we discover a larger maximum m_new, we can "correct" our previous sum by multiplying by exp(m_old - m_new). The math relies on a simple identity: exp(a + b) = exp(a) × exp(b). Therefore:
exp(x - m_old) = exp(x - m_new + m_new - m_old) = exp(x - m_new) × exp(m_new - m_old)
This shows how values computed with the old max can be converted to the new max with a single multiplication.
This correction is the key to processing tiles separately: if tile 2 has a larger max than tile 1, we retroactively adjust tile 1's results with a single multiply per element. That's the cost of processing in chunks instead of all at once—and it's exact, not an approximation.
Standard softmax requires two passes:
# Pass 1: find max for numerical stability
m = max(scores)
# Pass 2: compute exp and sum
exp_scores = exp(scores - m)
output = exp_scores / sum(exp_scores)
Online softmax computes this incrementally. When processing a new block of scores, we can update running statistics:
# Current state: (m_prev, l_prev, acc_prev)
# New block: scores_new
# Step 1: Update maximum
m_new = max(m_prev, max(scores_new))
# Step 2: Compute correction factor for previous accumulator
correction = exp(m_prev - m_new)
# Step 3: Update running sum with corrected old + new
l_new = correction * l_prev + sum(exp(scores_new - m_new))
# Step 4: Update running output (rescale old output + add new)
acc_new = correction * acc_prev + exp(scores_new - m_new) @ v_new
Key insight: Standard softmax needs two passes (find max, then compute). Online softmax does it in one pass by tracking a running max and correcting previous values when a new max is found.
Numerical verification: Let's process scores [2, 5, 3] in two blocks: [2, 5] then [3].
Block 1 ([2, 5]):
m = 5
l = exp(2-5) + exp(5-5) = 0.05 + 1.0 = 1.05
Block 2 ([3]):
m_new = max(5, 3) = 5 (unchanged)
correction = exp(5 - 5) = 1.0
l_new = 1.0 × 1.05 + exp(3-5) = 1.05 + 0.14 = 1.19
Standard softmax: sum(exp([2,5,3] - 5)) = 0.05 + 1.0 + 0.14 = 1.19 ✓
The online algorithm produces the exact same result! This allows computing exact softmax while only keeping O(block_size) values in SRAM instead of O(N).
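If you want to convince yourself beyond one hand-worked example, here's a short runnable check (a minimal sketch; the block size and input are arbitrary) that applies the running-max update block by block and compares against torch.softmax:

import torch

def online_softmax(x, block_size=4):
    """One pass over blocks of x, tracking a running max m and running sum l."""
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    for start in range(0, x.numel(), block_size):
        block = x[start:start + block_size]
        m_new = torch.maximum(m, block.max())
        l = l * torch.exp(m - m_new) + torch.exp(block - m_new).sum()  # correct old sum, add new block
        m = m_new
    return torch.exp(x - m) / l    # final normalization with the global max and sum

x = torch.randn(100)
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=0)))   # True

The real kernel additionally carries the output accumulator shown in the pseudocode below; this snippet isolates just the softmax statistics.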
IO Complexity Analysis
This is where FlashAttention's theory is elegant. Let:
- N = sequence length
- d = head dimension
- M = SRAM size
Standard attention HBM accesses: O(Nd + N^2)
- Read Q, K, V: O(Nd)
- Write and read N x N attention matrix: O(N^2)
FlashAttention HBM accesses: O(N^2 d^2 / M)
Where does this come from? With M bytes of SRAM and d-dimensional vectors, we can fit roughly M/d vectors per tile. Here's the step-by-step:
tile_size ~ M/d vectors
number of tile pairs ~ (N / tile_size)² = (Nd/M)²
each tile pair transfers ~ tile_size × d = M bytes
total HBM accesses ~ (Nd/M)² × M = N²d²/M
For typical values (N=2048, d=64, M=96KB SRAM):
- Standard: O(N^2) ≈ 4 million accesses
- FlashAttention: O(N^2 d^2 / M) ≈ 175K accesses
- Reduction: ~23x
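The same estimate in a few lines (a rough sketch; it counts element accesses with the same coarse constants as above, so treat the outputs as orders of magnitude rather than kernel-exact numbers):

N, d = 2048, 64
M = 96 * 1024                        # on-chip SRAM budget, in elements (rough)

standard = N * N                     # write + re-read the N x N score matrix (dominant term)
flash = (N * N * d * d) / M          # O(N^2 d^2 / M) tile traffic

print(f"standard:  ~{standard / 1e6:.1f}M accesses")     # ~4.2M
print(f"flash:     ~{flash / 1e3:.0f}K accesses")        # ~175K
print(f"reduction: ~{standard / flash:.0f}x")            # roughly 20-25x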
Theoretical Optimality
Dao et al. proved that FlashAttention is asymptotically optimal: for any SRAM size M between d and Nd, no exact-attention algorithm can make fewer than Ω(N^2 d^2 / M) HBM accesses. The only way to do better is to approximate attention.
FlashAttention Pseudocode
Here's the complete algorithm. The key points to notice are: (1) the outer loop over query blocks, (2) the inner loop over key/value blocks, and (3) the online softmax update that happens inside the inner loop.
def flash_attention(Q, K, V, block_size_q, block_size_kv):
"""
Q: (N, d) queries
K: (N, d) keys
V: (N, d) values
Returns: (N, d) attention output
"""
N, d = Q.shape
O = zeros((N, d)) # Output accumulator
L = zeros(N) # Normalizer (sum of exp)
M = full(N, -inf) # Running maximum
# Outer loop: iterate over query blocks
for i in range(0, N, block_size_q):
q_block = Q[i:i+block_size_q] # Load Q block to SRAM
o_block = zeros((block_size_q, d))
l_block = zeros(block_size_q)
m_block = full(block_size_q, -inf)
# Inner loop: iterate over key/value blocks
for j in range(0, N, block_size_kv):
k_block = K[j:j+block_size_kv] # Load K block to SRAM
v_block = V[j:j+block_size_kv] # Load V block to SRAM
# Compute attention scores for this tile
scores = q_block @ k_block.T / sqrt(d)
# Online softmax update
m_new = maximum(m_block, scores.max(dim=1))
# Rescale previous accumulator
scale = exp(m_block - m_new)
l_block = scale * l_block + exp(scores - m_new[:, None]).sum(dim=1)
# Update output accumulator
o_block = scale[:, None] * o_block + exp(scores - m_new[:, None]) @ v_block
m_block = m_new
        # Final normalization and write to HBM
        O[i:i+block_size_q] = o_block / l_block[:, None]
        L[i:i+block_size_q] = l_block   # per-row softmax statistics, saved for the backward pass
        M[i:i+block_size_q] = m_block
    return O
The key logic is in the inner loop (the for j in range block): each iteration updates the running maximum, rescales the previous accumulator, and adds the new contribution. When the inner loop completes, o_block contains the exact same result as standard attention would produce—but we never allocated an N × N matrix.
Choosing Block Sizes
Block size determines a trade-off: larger blocks (128-256) mean fewer SRAM trips but higher memory usage per block. Smaller blocks (32-64) allow more concurrent warps but increase transfer overhead. Typical values on A100: block_q=128, block_kv=64, fitting within the ~96KB SRAM budget. The actual FlashAttention implementation adapts block sizes based on available SRAM.
Performance Results
From the FlashAttention paper (Dao et al., 2022):
| Benchmark | Standard Attention | FlashAttention | Speedup |
|---|---|---|---|
| Attention alone (seq 2K) | 1x | 7.6x | 7.6x |
| GPT-2 training (seq 1K) | 1x | 3x | 3x |
| BERT-large (seq 512) | 1x | 1.15x | 1.15x |
| Long Range Arena (seq 4K) | 1x | 2.4x | 2.4x |
Note: BERT-large's smaller speedup (15%) reflects its sequence length of 512, where each head's 512² = 262K scores (~512 KB in FP16) largely fit in the A100's 40 MB L2 cache. At that scale, the tiling overhead roughly offsets the memory savings.
Memory savings at sequence length 4K: 20x reduction (no N^2 materialization).
What About the Backward Pass?
FlashAttention's backward pass uses similar tiling techniques to achieve O(N²d²/M) HBM accesses. The implementation is more complex (requires careful handling of recomputation vs. storage for gradients), but the asymptotic complexity is similar to the forward pass.
When FlashAttention Doesn't Help
FlashAttention has overhead from its tiling logic. For very short sequences (N < 256) where the N×N matrix already fits comfortably in SRAM, standard attention may be faster. Also, some implementations have head dimension constraints (e.g., d ≤ 128 or must be a multiple of 8)—check your library's documentation. FlashAttention shines at longer sequences where the quadratic memory would otherwise be prohibitive.
FlashAttention 2: Better Parallelism
FlashAttention 1 reduced HBM accesses dramatically, but GPU profiling revealed a surprise: utilization was still only 25-40% of theoretical peak. FlashAttention 2 (Dao, 2023) pushed this to 50-73% through better work partitioning.
Quick GPU Terminology
Before diving in, here's a brief refresher on GPU concepts:
- SM (Streaming Multiprocessor): A GPU has many SMs (108 on A100); each runs many threads in parallel
- Warp: A group of 32 threads that execute in lockstep on an SM
- Shared Memory: Fast on-chip SRAM (192KB per SM on A100) shared by all threads on that SM
- Synchronization: When threads must wait for each other before proceeding—expensive because it stalls computation
- Occupancy: What fraction of an SM's resources are actively being used
The Problem: Non-Matmul Operations
Here's a key insight: on modern GPUs, non-matmul FLOPs are 16x more expensive than matmul FLOPs in terms of throughput.
| Operation | A100 Throughput |
|---|---|
| FP16 Matrix Multiply | 312 TFLOPs |
| FP32 Scalar Exp | 19.5 TFLOPs |
This 16x gap exists because GPUs have specialized Tensor Cores for matrix multiply, but operations like exp() and division run on general-purpose ALUs. FlashAttention 1 spent too many cycles on softmax-related operations. FlashAttention 2 reduces these by:
- Reordering loops to minimize rescaling operations
- Keeping running statistics in registers instead of SRAM
The Fix: Better Warp Partitioning
FlashAttention 1 used "split-K" parallelism: different warps processed different parts of K/V and synchronized via shared memory to combine results.
FlashAttention 2 flips this: split Q across warps while each warp processes the full K/V. This eliminates synchronization between warps and reduces shared memory reads/writes.
FlashAttention 1 (split-K):
Warp 0: Q_full x K_part1 → partial_output_1 ─┐
Warp 1: Q_full x K_part2 → partial_output_2 ─┼→ sync → combine → output
Warp 2: Q_full x K_part3 → partial_output_3 ─┤
Warp 3: Q_full x K_part4 → partial_output_4 ─┘
FlashAttention 2 (split-Q):
Warp 0: Q_part1 x K_full → output_1 (no sync needed!)
Warp 1: Q_part2 x K_full → output_2
Warp 2: Q_part3 x K_full → output_3
Warp 3: Q_part4 x K_full → output_4
Additional Optimizations
- Sequence-level parallelism: FA1 only parallelized across batch and heads. FA2 also parallelizes across the sequence dimension, improving utilization for small batch sizes.
- Reduced shared memory traffic: By keeping more state in registers, FA2 cuts shared memory reads/writes significantly.
- Better occupancy: Careful tuning of block sizes to balance register usage and SM occupancy.
Performance Results
| Configuration | FA1 | FA2 | Speedup |
|---|---|---|---|
| Forward pass, seq 2K | 124 TFLOPs | 230 TFLOPs | 1.9x |
| End-to-end training | - | 225 TFLOPs | - |
| GPU utilization | 25-40% | 50-73% | ~2x |
FlashAttention 2 achieves approximately 9x speedup over PyTorch standard attention on long sequences (from Dao, 2023) and 2x speedup over FlashAttention 1.
FlashAttention 3: Hopper-Specific Optimizations
FlashAttention 2's optimizations were designed for the A100 architecture. When NVIDIA released the H100 with fundamentally new hardware features, new optimization opportunities emerged that FlashAttention 2 couldn't exploit. FlashAttention 3 (Dao et al., 2024) targets these specifically, achieving 75% utilization and 740 TFLOPs on H100.
Note
This section covers H100-specific optimizations. You don't need to understand these details to benefit from FlashAttention 3: the practical guidance in the next section applies unchanged, and frameworks select the fastest kernel available for your hardware. But if you're curious about why it's roughly 2x faster, read on.
New H100 Features
The H100 introduced hardware specifically designed for transformer workloads:
- WGMMA (Warpgroup Matrix Multiply-Accumulate): A new matrix multiply instruction that processes larger tiles than A100's equivalent, roughly doubling throughput for the same operation
- TMA (Tensor Memory Accelerator): A dedicated hardware unit for memory transfers. Previously, GPU cores had to spend cycles calculating addresses and issuing loads. TMA handles this automatically, freeing cores to compute
- Larger shared memory: 228 KB per SM vs 192 KB on A100, allowing larger tiles
Three Key Techniques
1. Warp Specialization
Instead of having all warps do the same work, FA3 assigns different roles:
- Producer warps: Use TMA to asynchronously fetch next K/V blocks
- Consumer warps: Compute attention on current blocks
This overlaps data movement with computation, hiding memory latency.
2. Pingpong Scheduling
The expensive non-matmul operations (softmax exp, rescaling) happen between matrix multiplies. FA3 interleaves two warpgroups:
Time →
Warpgroup 0: [GEMM K0] [softmax] [GEMM V0] [softmax] [GEMM K2] ...
Warpgroup 1: [GEMM K1] [softmax] [GEMM V1] [softmax] ...
While one warpgroup does softmax, the other does GEMM. This hides the throughput gap between these operations—on H100, this gap reaches ~250x (989 TFLOPs for matmul vs ~4 TFLOPs for scalar ops), making overlap even more critical than on A100's 16x gap.
3. FP8 with Incoherent Processing
H100 supports FP8 (8-bit floating point) with 2x the throughput of FP16. But naive FP8 attention has high quantization error due to outliers in attention scores.
FA3 uses incoherent processing: apply Hadamard transforms (structured orthogonal matrices that redistribute values across dimensions) to Q and K before quantization. This spreads outliers across dimensions rather than concentrating them, reducing quantization error by 2.6x while maintaining speed benefits.
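Here's a tiny demonstration of why the transform is mathematically free (a sketch using a random orthogonal matrix as a stand-in for the scaled Hadamard transform FA3 actually uses): rotating Q and K by the same orthogonal matrix leaves QK^T unchanged while spreading an outlier dimension across many coordinates before quantization.

import torch

torch.manual_seed(0)
N, d = 16, 64
Q, K = torch.randn(N, d), torch.randn(N, d)
Q[:, 0] *= 50                                    # inject an outlier dimension

# Random orthogonal matrix (stand-in for the structured Hadamard transform)
rot, _ = torch.linalg.qr(torch.randn(d, d))

Qr, Kr = Q @ rot, K @ rot                        # rotate Q and K with the SAME matrix
print((Q @ K.T - Qr @ Kr.T).abs().max())         # tiny: attention scores are unchanged
print(Q.abs().max(), Qr.abs().max())             # the per-element outlier magnitude typically shrinks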
Performance Results
| Precision | FlashAttention 2 | FlashAttention 3 | Speedup |
|---|---|---|---|
| FP16 | ~370 TFLOPs | 740 TFLOPs | 2.0x |
| FP8 | N/A | ~1,200 TFLOPs | - |
FA3 reaches 75% of H100's theoretical FP16 peak, up from 35% with FA2.
When to Use FA3
FA3 is specifically optimized for H100 GPUs and requires CUDA 12+. If you're on A100 or older hardware, stick with FA2. The good news: PyTorch's scaled_dot_product_attention automatically selects the best available backend.
Using FlashAttention in Practice
PyTorch Native (Recommended)
Since PyTorch 2.0, FlashAttention is built-in:
import torch
import torch.nn.functional as F
# Automatic backend selection (FlashAttention when available)
# query, key, value: (batch, num_heads, seq_len, head_dim)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,      # Optional attention mask (must be None to hit the flash backend)
    dropout_p=0.0,       # Dropout probability
    is_causal=True,      # Use causal masking (for autoregressive models)
)
PyTorch automatically selects the best backend:
- FlashAttention (if supported)
- Memory-efficient attention (xFormers-style)
- Standard attention (fallback)
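To check which backend actually runs, or to force one, PyTorch 2.3+ exposes a context manager (a sketch, assuming a CUDA device and FP16 inputs since the flash backend only handles half precision on GPU; SDPA raises an error if the requested backend can't serve the inputs):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Restrict SDPA to the FlashAttention backend for this block
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)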
Direct FlashAttention Library
For more control or features not in PyTorch:
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# query, key, value: (batch, seqlen, nheads, headdim)
output = flash_attn_func(
query, key, value,
dropout_p=0.0,
softmax_scale=None, # Default: 1/sqrt(headdim)
causal=True,
)
Integration with Transformers
Most model libraries support FlashAttention automatically:
# HuggingFace Transformers
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2", # Enable FlashAttention
torch_dtype=torch.float16,
)
Combining KV Cache and FlashAttention
These optimizations are complementary:
- KV Cache: Reduces redundant K/V computation from O(N^2) to O(N) across generation steps
- FlashAttention: Reduces memory access from O(N^2) to O(N^2 d^2/M) for each attention computation
For maximum performance, use both:
# Example: HuggingFace with both optimizations
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
)
outputs = model.generate(
input_ids,
max_new_tokens=100,
use_cache=True, # Enable KV caching
)
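To see the KV-cache contribution in isolation, you can time generation with caching turned off (a sketch; model and input_ids are assumed to come from the snippet above, and the exact gap depends on hardware and sequence length):

import time

def timed_generate(use_cache):
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(input_ids, max_new_tokens=100, use_cache=use_cache)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"with KV cache:    {timed_generate(True):.2f} s")
print(f"without KV cache: {timed_generate(False):.2f} s")   # typically several times slower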
Conclusion
Understanding why attention is slow is the first step to fixing it. The key insights:
-
Attention is memory-bound, not compute-bound. The N^2 attention matrix causes memory bandwidth to become the bottleneck, not compute throughput.
-
KV caching eliminates redundant computation. By caching keys and values, we reduce per-step K/V computation from O(N) to O(1), giving 5-10x speedups for generation.
-
FlashAttention is exact attention, not an approximation. It uses tiling and online softmax to avoid materializing the N^2 matrix, reducing memory access by 10-20x.
-
Hardware-aware optimization matters. FlashAttention 2 and 3 exploit specific GPU features (warp partitioning, TMA, WGMMA) to achieve 50-75% utilization where naive implementations achieve 25-30%.
-
Use both KV caching and FlashAttention together. They solve different problems and stack multiplicatively.
For most users, the practical advice is simple: use PyTorch 2.0+ with scaled_dot_product_attention, enable KV caching for generation, and the framework handles the rest.
Your Next Steps
- Profile your model with torch.profiler to identify whether attention is your bottleneck (a minimal sketch follows this list)
- Verify FlashAttention is being used: set TORCH_LOGS="+dynamo" to see backend selection
- For serving workloads, explore vLLM or TensorRT-LLM, which integrate these optimizations with batching and PagedAttention
- Monitor memory: use torch.cuda.memory_summary() to see where memory goes
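For the first step, here's a minimal profiling sketch (it assumes model and input_ids already live on the GPU; look for attention and softmax kernels near the top of the table):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.generate(input_ids, max_new_tokens=20)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))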
Understanding these techniques transforms you from someone who throws hardware at performance problems to someone who can diagnose and fix them at the algorithmic level.
Bonus: Simple Tiled Attention in Pure PyTorch
For educational purposes, here's a minimal tiled attention implementation that demonstrates the core FlashAttention concept without CUDA:
import torch
import math
def tiled_attention(Q, K, V, block_size=256):
"""
Simplified tiled attention for educational purposes.
Not optimized for production - use flash_attn or PyTorch's SDPA instead.
Q, K, V: (batch, seq_len, d_model)
Returns: (batch, seq_len, d_model)
"""
batch, seq_len, d_model = Q.shape
scale = 1.0 / math.sqrt(d_model)
    # Output accumulator (per-block softmax statistics live inside the loop below)
    O = torch.zeros_like(Q)
# Iterate over query blocks
for i in range(0, seq_len, block_size):
q_end = min(i + block_size, seq_len)
q_block = Q[:, i:q_end, :] # (batch, block_q, d)
# Local accumulators for this query block
o_block = torch.zeros_like(q_block)
l_block = torch.zeros(batch, q_end - i, 1, device=Q.device)
m_block = torch.full((batch, q_end - i, 1), float('-inf'), device=Q.device)
# Iterate over key/value blocks
for j in range(0, seq_len, block_size):
k_end = min(j + block_size, seq_len)
k_block = K[:, j:k_end, :] # (batch, block_kv, d)
v_block = V[:, j:k_end, :]
# Compute attention scores: (batch, block_q, block_kv)
scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale
# Online softmax update
m_new = torch.maximum(m_block, scores.max(dim=-1, keepdim=True).values)
# Rescale previous accumulator
exp_diff = torch.exp(m_block - m_new)
l_block = exp_diff * l_block + torch.exp(scores - m_new).sum(dim=-1, keepdim=True)
# Update output: rescale old + add new contribution
o_block = exp_diff * o_block + torch.bmm(torch.exp(scores - m_new), v_block)
m_block = m_new
# Final normalization and store
O[:, i:q_end, :] = o_block / l_block
return O
# Test correctness
def test_tiled_attention():
torch.manual_seed(42)
batch, seq_len, d_model = 2, 1024, 64
Q = torch.randn(batch, seq_len, d_model)
K = torch.randn(batch, seq_len, d_model)
V = torch.randn(batch, seq_len, d_model)
# Reference: standard attention
scale = 1.0 / math.sqrt(d_model)
attn_weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) * scale, dim=-1)
reference = torch.bmm(attn_weights, V)
# Our tiled implementation
tiled = tiled_attention(Q, K, V, block_size=128)
# Check correctness
max_diff = (reference - tiled).abs().max().item()
print(f"Max difference from reference: {max_diff:.2e}")
assert max_diff < 1e-5, "Tiled attention differs from reference!"
print("Test passed!")
if __name__ == "__main__":
test_tiled_attention()
This implementation shows the key ideas (tiling, online softmax) but runs in Python and won't be faster than standard attention. The real speedup comes from the CUDA implementation that keeps tiles in SRAM and fuses all operations into a single kernel.
Further Reading
Papers:
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. arXiv:2205.14135
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. arXiv:2307.08691
- Dao et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." NeurIPS 2024. arXiv:2407.08608
- Shazeer (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. arXiv:2305.13245
- Kwon et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023. arXiv:2309.06180
Educational Resources:
- Sebastian Raschka: "Understanding and Coding the KV Cache in LLMs from Scratch"
- Tri Dao's FlashAttention-3 blog: tridao.me/blog/2024/flash3/
- Jay Alammar: "The Illustrated Transformer"
Code:
- Official FlashAttention: github.com/Dao-AILab/flash-attention
- vLLM (PagedAttention): github.com/vllm-project/vllm