
Attention Optimization: From Memory Walls to Flash Attention

You deployed your first LLM. Inference is 10x slower than you expected, and your GPU utilization hovers around 30%. You throw money at bigger GPUs, but throughput barely improves. What went wrong?

The bottleneck isn't compute. It's memory bandwidth.

I've seen this pattern repeatedly when teams first deploy transformers at scale. They profile their model expecting to find slow matrix multiplies, but instead discover that attention spends most of its time waiting for data to move between memory hierarchies—specifically, reading from and writing to HBM (High Bandwidth Memory), the GPU's main memory. An A100 can perform 312 trillion FP16 operations per second, but HBM can only transfer 2 terabytes of data per second. When your algorithm requires moving N^2 bytes to compute N^2 operations, you're leaving most of that compute power idle.

This article explains why attention is memory-bound and how three techniques fix it: KV caching for eliminating redundant computation, and FlashAttention (1, 2, and 3) for making the remaining computation IO-efficient. By the end, you'll understand the memory hierarchy that governs GPU performance, the tiling tricks that make FlashAttention work, and when to use each optimization.

What you'll learn:

  • Why attention is memory-bound, not compute-bound
  • How KV caching eliminates O(N) redundant key/value computations
  • How FlashAttention reduces memory access from O(N^2) to O(N^2 d^2 / M), where M is SRAM size
  • The key innovations in FlashAttention 2 and 3
  • Practical decision framework for which optimizations to use

Prerequisites: Familiarity with the transformer attention mechanism (softmax(QK^T/sqrt(d))V), basic understanding of GPU architecture (what CUDA cores and memory are), and comfort reading PyTorch code.

A note on scope: This article focuses on the algorithmic ideas behind attention optimization. I've simplified some GPU architecture details where the core insight doesn't depend on them. For the full implementation details, the FlashAttention papers (linked at the end) are excellent.

With that context, let's start with practical guidance—then build up to the theory.

#Quick Decision Framework

Before diving into the details, here's what to use when:

| Situation | Recommendation |
|---|---|
| Any production inference | Enable KV cache (always) |
| PyTorch 2.0+ | Use scaled_dot_product_attention (includes FlashAttention) |
| Training or long sequences | Use the FlashAttention library directly |
| H100 GPUs | Use FlashAttention-3 for best performance |
| Serving many concurrent requests | Add PagedAttention (vLLM) |
| Need to reduce KV cache memory | Consider GQA or MQA architectures |

#The Memory Wall

To understand why your GPU sits idle at 30% utilization, we need to look at what attention actually does to memory.

The standard attention formula looks innocent enough:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V

(The √d scaling is critical: without it, dot products grow with dimension, making softmax nearly one-hot. The scaling keeps attention scores in a reasonable range regardless of head dimension.)
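A quick numerical check of that claim, as a throwaway sketch: the standard deviation of the dot product of two random unit-variance d-dimensional vectors grows like √d, and dividing by √d brings it back to roughly 1.

Python
import torch

torch.manual_seed(0)
for d in (16, 256, 4096):
    q, k = torch.randn(10_000, d), torch.randn(10_000, d)
    raw = (q * k).sum(dim=-1)              # 10,000 random dot products
    print(d, raw.std().item(), (raw / d**0.5).std().item())
    # std grows like sqrt(d); after scaling it stays near 1 for every d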

For a sequence of length N with head dimension d, this requires:

  1. Compute QK^T: O(N^2 d) FLOPs, producing an N x N matrix
  2. Apply softmax row-wise: O(N^2) operations
  3. Multiply by V: O(N^2 d) FLOPs

Total compute: O(N^2 d). That's a lot of operations, but modern GPUs are fast at matrix multiplication. An A100 can do 312 TFLOPs (trillion floating-point operations) of FP16 math per second.

Here's the problem: the standard implementation materializes (allocates and stores) the N x N attention matrix in GPU memory. For a sequence length of 8,192 with 32 attention heads, that's:

8192^2 \times 32 \times 2 \text{ bytes} = 4.3 \text{ GB per layer}

And you need to read this matrix back to multiply by V. With 2 TB/s memory bandwidth, just moving this data takes 2+ milliseconds per layer. A 32-layer model spends 70+ milliseconds just on memory transfers for attention alone.
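Here's a rough back-of-envelope check of those numbers, as a sketch only (real kernels overlap some transfers with compute):

Python
seq_len, heads, bytes_per_elem = 8192, 32, 2      # FP16 attention scores
hbm_bandwidth = 2e12                              # A100 HBM: ~2 TB/s

attn_matrix_bytes = seq_len**2 * heads * bytes_per_elem      # ~4.3 GB per layer
read_time_ms = attn_matrix_bytes / hbm_bandwidth * 1e3       # ~2.1 ms just to read it back
print(attn_matrix_bytes / 1e9, read_time_ms, 32 * read_time_ms)  # ~4.3 GB, ~2.1 ms, ~69 ms over 32 layers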

#Arithmetic Intensity: The Real Bottleneck

The concept that explains this is arithmetic intensity: the ratio of compute operations to bytes moved.

\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Transferred}}

For a GPU to be compute-bound (using all its FLOPs), the arithmetic intensity of your algorithm must exceed the GPU's balance point—the minimum FLOPs per byte needed to saturate compute. Below this threshold, memory bandwidth is the bottleneck; above it, compute can run at full speed.

The balance point is simply: Peak FLOPs ÷ Memory Bandwidth.

| GPU | Memory Bandwidth | Peak FP16 | Balance Point |
|---|---|---|---|
| A100 | 2.0 TB/s | 312 TFLOPs | 312 ÷ 2 = 156 FLOPs/byte |
| H100 | 3.35 TB/s | 989 TFLOPs | 989 ÷ 3.35 ≈ 295 FLOPs/byte |

Standard attention's arithmetic intensity for the softmax step? About 1 FLOP per byte. Let's see where that comes from:

Softmax over N elements:
- FLOPs: ~4N (find max, subtract max, compute exp, divide by sum)
- Memory: Read N values (2N bytes in FP16) + Write N values (2N bytes)
- Arithmetic Intensity: 4N FLOPs / 4N bytes = 1 FLOP/byte

We need 156x more compute per byte to saturate an A100. Softmax is dramatically below this threshold—it's almost entirely memory-bound.
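Here's the same roofline-style comparison written out as a minimal sketch, using the numbers from the table above:

Python
# Is row-wise softmax compute-bound or memory-bound on an A100?
peak_flops = 312e12                          # FP16 peak
hbm_bandwidth = 2.0e12                       # bytes/s
balance_point = peak_flops / hbm_bandwidth   # ~156 FLOPs per byte

N = 8192
softmax_flops = 4 * N                        # find max, subtract, exp, divide (rough count)
softmax_bytes = 2 * N + 2 * N                # read N FP16 values + write N FP16 values
intensity = softmax_flops / softmax_bytes
print(balance_point, intensity)              # 156.0 vs 1.0 -> far below the threshold: memory-bound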

#The Memory Hierarchy

Understanding GPU memory hierarchy explains why materialization is so expensive:

| Level | Size | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~256 KB per SM | - | 1 cycle |
| SRAM (Shared Memory) | 192 KB per SM (A100) | ~19 TB/s | ~30 cycles |
| L2 Cache | 40 MB total | ~5 TB/s | ~200 cycles |
| HBM (Global Memory) | 80 GB | 2 TB/s | ~400 cycles |

SM = Streaming Multiprocessor, the GPU's basic compute unit. An A100 has 108 SMs.

The key insight: SRAM is roughly 10x faster than HBM but orders of magnitude smaller (about 20 MB across all SMs versus 80 GB of HBM). Standard attention writes the N x N matrix to HBM because it doesn't fit in SRAM. FlashAttention reorganizes the computation so it never needs to.


#KV Cache: Eliminating Redundant Computation

Before tackling the memory efficiency of attention itself, let's address a different kind of waste: recomputing the same keys and values during autoregressive generation.

When an LLM generates text, it produces one token at a time. At each step, the model attends to all previous tokens. Without caching, this means recomputing K and V for the entire history at every generation step.

Consider generating 100 tokens:

  • Step 1: Compute K, V for token 1
  • Step 2: Recompute K, V for tokens 1-2 (token 1 was already computed!)
  • Step 3: Recompute K, V for tokens 1-3 (tokens 1-2 were already computed!)
  • ...
  • Step 100: Recompute K, V for tokens 1-100

Total K/V computations: 1 + 2 + 3 + ... + 100 = 5,050. But we only needed 100 unique computations.
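In code, the gap is just a triangular sum versus a linear count (a toy illustration):

Python
N = 100
kv_without_cache = sum(range(1, N + 1))   # recompute K/V for all t tokens at step t: 1 + 2 + ... + N
kv_with_cache = N                         # compute K/V exactly once per token
print(kv_without_cache, kv_with_cache)    # 5050 vs 100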


#How KV Caching Works

The idea is simple: store previously computed keys and values, and only compute K/V for new tokens.

Without KV Cache:

For each new token:
    Compute Q, K, V for ALL tokens (0 to t)
    Attention(Q, K, V)
    Return last token's output

With KV Cache:

Initialize empty cache

For the prompt:
    Compute K, V for all prompt tokens
    Store in cache

For each new token:
    Compute Q, K, V for ONLY the new token
    Concatenate K_new, V_new to cache
    Attention(Q_new, K_cached, V_cached)
    Return output

#Implementation

Here's the core modification to a multi-head attention module:

#1. Register Cache Buffers

Python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # Cache for keys and values
        self.register_buffer("cache_k", None)
        self.register_buffer("cache_v", None)

#2. Forward Pass with Caching

Python
def forward(self, x, use_cache=False):
    batch, seq_len, _ = x.shape

    # Compute Q, K, V projections for new tokens
    q = self.W_q(x)
    k_new = self.W_k(x)
    v_new = self.W_v(x)

    if use_cache:
        if self.cache_k is None:
            # First call: initialize cache
            self.cache_k = k_new
            self.cache_v = v_new
        else:
            # Subsequent calls: append to cache
            self.cache_k = torch.cat([self.cache_k, k_new], dim=1)
            self.cache_v = torch.cat([self.cache_v, v_new], dim=1)

        # Use full cache for attention
        k, v = self.cache_k, self.cache_v
    else:
        k, v = k_new, v_new

    # Standard attention computation
    # q: (batch, 1, d_model) for generation
    # k, v: (batch, cached_len, d_model)
    attn_out = scaled_dot_product_attention(q, k, v)
    return self.W_o(attn_out)

#3. Reset Between Sequences

Python
def reset_cache(self):
    self.cache_k = None
    self.cache_v = None
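Here's a rough usage sketch of the caching pattern above, with random tensors standing in for real token embeddings. It assumes scaled_dot_product_attention was imported from torch.nn.functional in the module that defines the class:

Python
import torch

attn = MultiHeadAttention(d_model=64, num_heads=8)
attn.reset_cache()

prompt = torch.randn(1, 10, 64)          # (batch, prompt_len, d_model)
_ = attn(prompt, use_cache=True)         # Prefill: K/V for all 10 prompt tokens enter the cache

for step in range(5):                    # Decode: one new token at a time
    new_token = torch.randn(1, 1, 64)    # (batch, 1, d_model)
    out = attn(new_token, use_cache=True)
    print(step, attn.cache_k.shape)      # Cache grows: (1, 11, 64), (1, 12, 64), ...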

#KV Cache Trade-offs

Advantages:

  • (+) Eliminates O(N) redundant key/value computations per step
  • (+) Reduces total K/V projection compute across a generation from O(N^2) to O(N)
  • (+) 5-10x speedup for typical generation lengths
  • (+) Simple to implement

Disadvantages:

  • (-) Memory grows linearly with sequence length
  • (-) For batch size 512, context 2048, Llama-70B: KV cache requires ~2.7 TB (32 KB × 80 layers × 2048 × 512 = 2.7 TB—that's ~19× the model's 140 GB weight memory)
  • (-) Adds code complexity for position tracking
  • (-) Must manage cache lifecycle carefully

#Memory Estimates

KV cache memory per token per layer:

\text{Memory} = 2 \times n_{\text{heads}} \times d_{\text{head}} \times \text{precision\_bytes}

For Llama 2 70B (80 layers, 64 heads, head dim 128, FP16):

2 \times 64 \times 128 \times 2 = 32 \text{ KB per token per layer}

At 4096 context length: 32 KB x 80 layers x 4096 tokens = 10.5 GB per sequence

This is why techniques like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) matter: they share K/V heads to reduce cache size by 8-32x.

Full KV cache formula (total across all layers and the whole batch):

2 × heads × head_dim × precision_bytes × layers × seq_len × batch_size

Note: At sequence length 2,048, standard attention materializes a 2,048 × 2,048 = 4,194,304-element matrix per head per layer. FlashAttention avoids this entirely through tiling.
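To make these estimates concrete, here's a minimal back-of-envelope sketch. The function names and the n_kv_heads parameter (which models GQA/MQA-style K/V head sharing) are illustrative additions, not part of any library:

Python
def kv_cache_bytes(layers, n_kv_heads, head_dim, seq_len, batch_size, precision_bytes=2):
    """Total KV cache: 2 (K and V) x kv_heads x head_dim x bytes x layers x seq_len x batch."""
    return 2 * n_kv_heads * head_dim * precision_bytes * layers * seq_len * batch_size

def attn_scores_bytes_per_layer(n_heads, seq_len, precision_bytes=2):
    """Memory to materialize the N x N attention scores for all heads of one layer."""
    return n_heads * seq_len * seq_len * precision_bytes

# Llama 2 70B-like config: 80 layers, 64 heads, head_dim 128, FP16
print(kv_cache_bytes(80, 64, 128, seq_len=4096, batch_size=1) / 1e9)  # ~10.7 GB per sequence (the ~10.5 GB figure above)
print(kv_cache_bytes(80, 8, 128, seq_len=4096, batch_size=1) / 1e9)   # ~1.3 GB with 8 KV heads (GQA)
print(attn_scores_bytes_per_layer(32, 8192) / 1e9)                    # ~4.3 GB, the per-layer figure from earlier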

#FlashAttention 1: IO-Aware Exact Attention

KV caching handles redundant computation. But each individual attention computation still materializes that N x N matrix. FlashAttention fixes this.

#The Core Idea: Tiling

Instead of computing the full attention matrix at once, FlashAttention processes it in tiles that fit in fast SRAM. The algorithm:

  1. Divide Q, K, V into blocks that fit in SRAM
  2. For each block of Q:
    • Load Q block to SRAM
    • For each block of K, V:
      • Load K, V blocks to SRAM
      • Compute partial attention scores
      • Update running output using online softmax
    • Write output block to HBM
  3. Never materialize the full N x N matrix

The challenge is softmax: it requires knowing the maximum and sum over the entire row to normalize. How do you compute softmax when you only see parts of the row at a time?

FlashAttention Tiling Visualization: an 8×8 attention matrix processed as sixteen 2×2 tiles, with only the current Q, K, and V blocks resident in SRAM.

Key insight: Instead of materializing the full 8×8 = 64-element attention matrix in HBM, FlashAttention processes it in 2×2 tiles that fit in fast SRAM: 16 tiles of only 4 elements each.

#Online Softmax: The Clever Trick

Here's the challenge: softmax needs global information. To compute softmax(x_i), you need:

  1. The maximum value across ALL elements (for numerical stability)
  2. The sum of exp(x_j - max) across ALL elements (for normalization)

But we're processing tiles. When we see tile 1, we don't know what values are in tile 2. How can we compute a correct softmax without seeing the whole row?

The key insight is that softmax(x) = exp(x - m) / sum(exp(x - m)) remains valid for any value of m—we just need to use the same m everywhere.

When we discover a larger maximum m_new, we can "correct" our previous sum by multiplying by exp(m_old - m_new). The math relies on a simple identity: exp(a + b) = exp(a) × exp(b). Therefore:

exp(x - m_old) = exp(x - m_new + m_new - m_old) = exp(x - m_new) × exp(m_new - m_old)

This shows how values computed with the old max can be converted to the new max with a single multiplication.

This correction is the key to processing tiles separately: if tile 2 has a larger max than tile 1, we retroactively adjust tile 1's results with a single multiply per element. That's the cost of processing in chunks instead of all at once—and it's exact, not an approximation.

Standard softmax requires two passes:

# Pass 1: find max for numerical stability
m = max(scores)
# Pass 2: compute exp and sum
exp_scores = exp(scores - m)
output = exp_scores / sum(exp_scores)

Online softmax computes this incrementally. When processing a new block of scores, we can update running statistics:

Python
# Current state: (m_prev, l_prev, acc_prev)
# New block: scores_new

# Step 1: Update maximum
m_new = max(m_prev, max(scores_new))

# Step 2: Compute correction factor for previous accumulator
correction = exp(m_prev - m_new)

# Step 3: Update running sum with corrected old + new
l_new = correction * l_prev + sum(exp(scores_new - m_new))

# Step 4: Update running output (rescale old output + add new)
acc_new = correction * acc_prev + exp(scores_new - m_new) @ v_new

Online Softmax Algorithm: the scores are split into blocks and processed one at a time, keeping only a running max (m) and running sum (ℓ) in SRAM.

Key insight: Standard softmax needs two passes (find max, then compute). Online softmax does it in one pass by tracking a running max and correcting previous values when a new max is found.

Numerical verification: Let's process scores [2, 5, 3] in two blocks: [2, 5] then [3].

Block 1 ([2, 5]):
  m = 5
  l = exp(2-5) + exp(5-5) = 0.05 + 1.0 = 1.05

Block 2 ([3]):
  m_new = max(5, 3) = 5 (unchanged)
  correction = exp(5 - 5) = 1.0
  l_new = 1.0 × 1.05 + exp(3-5) = 1.05 + 0.14 = 1.19

Standard softmax: sum(exp([2,5,3] - 5)) = 0.05 + 1.0 + 0.14 = 1.19 ✓

The online algorithm produces the exact same result! This allows computing exact softmax while only keeping O(block_size) values in SRAM instead of O(N).
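The same check can be scripted. Here's a minimal sketch (names are mine, not from the paper) that builds the running max and normalizer block by block and compares against the two-pass reference:

Python
import math
import torch

def online_normalizer(scores, block_size):
    """One pass over blocks, keeping only a running max m and running sum l."""
    m, l = float('-inf'), 0.0
    for block in scores.split(block_size):
        m_new = max(m, block.max().item())
        l = math.exp(m - m_new) * l + torch.exp(block - m_new).sum().item()
        m = m_new
    return m, l

scores = torch.tensor([2.0, 5.0, 3.0])
print(online_normalizer(scores, block_size=2))                              # (5.0, ~1.19)
print(scores.max().item(), torch.exp(scores - scores.max()).sum().item())   # same values, two passes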

#IO Complexity Analysis

This is where FlashAttention's theory is elegant. Let:

  • N = sequence length
  • d = head dimension
  • M = SRAM size

Standard attention HBM accesses: O(Nd + N^2)

  • Read Q, K, V: O(Nd)
  • Write and read N x N attention matrix: O(N^2)

FlashAttention HBM accesses: O(N^2 d^2 / M)

Where does this come from? With M bytes of SRAM and d-dimensional vectors, we can fit roughly M/d vectors per tile. Here's the step-by-step:

tile_size ~ M/d vectors
number of tile pairs ~ (N / tile_size)² = (Nd/M)²
each tile pair transfers ~ tile_size × d = M bytes
total HBM accesses ~ (Nd/M)² × M = N²d²/M

For typical values (N=2048, d=64, M=96KB SRAM):

  • Standard: O(N^2) ≈ 4 million accesses
  • FlashAttention: O(N^2 d^2 / M) ≈ 175K accesses
  • Reduction: ~23x

#FlashAttention Pseudocode

Here's the complete algorithm. The key points to notice are: (1) the outer loop over query blocks, (2) the inner loop over key/value blocks, and (3) the online softmax update that happens inside the inner loop.

Python
import math
import torch

def flash_attention(Q, K, V, block_size_q, block_size_kv):
    """
    Q: (N, d) queries
    K: (N, d) keys
    V: (N, d) values
    Returns: (N, d) attention output
    """
    N, d = Q.shape
    O = Q.new_zeros((N, d))  # Output accumulator in HBM

    # Outer loop: iterate over query blocks
    for i in range(0, N, block_size_q):
        q_block = Q[i:i+block_size_q]  # Load Q block to SRAM
        bq = q_block.shape[0]          # Last block may be smaller than block_size_q

        o_block = Q.new_zeros((bq, d))               # Unnormalized output for this Q block
        l_block = Q.new_zeros(bq)                    # Normalizer (running sum of exp)
        m_block = Q.new_full((bq,), float('-inf'))   # Running maximum

        # Inner loop: iterate over key/value blocks
        for j in range(0, N, block_size_kv):
            k_block = K[j:j+block_size_kv]  # Load K block to SRAM
            v_block = V[j:j+block_size_kv]  # Load V block to SRAM

            # Compute attention scores for this tile
            scores = q_block @ k_block.T / math.sqrt(d)

            # Online softmax update: new running maximum per query row
            m_new = torch.maximum(m_block, scores.max(dim=1).values)

            # Rescale previous accumulators and add this tile's contribution
            scale = torch.exp(m_block - m_new)
            p = torch.exp(scores - m_new[:, None])
            l_block = scale * l_block + p.sum(dim=1)

            # Update output accumulator (rescale old output + add new)
            o_block = scale[:, None] * o_block + p @ v_block
            m_block = m_new

        # Final normalization and write to HBM
        O[i:i+block_size_q] = o_block / l_block[:, None]

    return O

The key logic is in the inner loop (the for j in range block): each iteration updates the running maximum, rescales the previous accumulator, and adds the new contribution. When the inner loop completes, o_block contains the exact same result as standard attention would produce—but we never allocated an N × N matrix.

#Performance Results

From the FlashAttention paper (Dao et al., 2022):

| Benchmark | Standard Attention | FlashAttention | Speedup |
|---|---|---|---|
| Attention alone (seq 2K) | 1x | 7.6x | 7.6x |
| GPT-2 training (seq 1K) | 1x | 3x | 3x |
| BERT-large (seq 512) | 1x | 1.15x | 15% |
| Long Range Arena (seq 4K) | 1x | 2.4x | 2.4x |

Note: BERT-large's smaller speedup (15%) reflects its sequence length of 512: the 512² = 262K attention scores per head (~512 KB in FP16) largely fit in the A100's 40 MB L2 cache, so at this scale tiling overhead roughly offsets the memory savings.

Memory savings at sequence length 4K: 20x reduction (no N^2 materialization).

#FlashAttention 2: Better Parallelism

FlashAttention 1 reduced HBM accesses dramatically, but GPU profiling revealed a surprise: utilization was still only 25-40% of theoretical peak. FlashAttention 2 (Dao, 2023) pushed this to 50-73% through better work partitioning.

#The Problem: Non-Matmul Operations

Here's a key insight: on modern GPUs, non-matmul FLOPs are 16x more expensive than matmul FLOPs in terms of throughput.

| Operation | A100 Throughput |
|---|---|
| FP16 Matrix Multiply | 312 TFLOPs |
| FP32 Scalar Exp | 19.5 TFLOPs |

This 16x gap exists because GPUs have specialized Tensor Cores for matrix multiply, but operations like exp() and division run on general-purpose ALUs. FlashAttention 1 spent too many cycles on softmax-related operations. FlashAttention 2 reduces these by:

  1. Reordering loops to minimize rescaling operations
  2. Keeping running statistics in registers instead of SRAM

#The Fix: Better Warp Partitioning

FlashAttention 1 used "split-K" parallelism: different warps processed different parts of K/V and synchronized via shared memory to combine results.

FlashAttention 2 flips this: split Q across warps while each warp processes the full K/V. This eliminates synchronization between warps and reduces shared memory reads/writes.

FlashAttention 1 (split-K):
  Warp 0: Q_full x K_part1 → partial_output_1 ─┐
  Warp 1: Q_full x K_part2 → partial_output_2 ─┼→ sync → combine → output
  Warp 2: Q_full x K_part3 → partial_output_3 ─┤
  Warp 3: Q_full x K_part4 → partial_output_4 ─┘

FlashAttention 2 (split-Q):
  Warp 0: Q_part1 x K_full → output_1  (no sync needed!)
  Warp 1: Q_part2 x K_full → output_2
  Warp 2: Q_part3 x K_full → output_3
  Warp 3: Q_part4 x K_full → output_4

#Additional Optimizations

  1. Sequence-level parallelism: FA1 only parallelized across batch and heads. FA2 also parallelizes across the sequence dimension, improving utilization for small batch sizes.

  2. Reduced shared memory traffic: By keeping more state in registers, FA2 cuts shared memory reads/writes significantly.

  3. Better occupancy: Careful tuning of block sizes to balance register usage and SM occupancy.

#Performance Results

| Configuration | FA1 | FA2 | Speedup |
|---|---|---|---|
| Forward pass, seq 2K | 124 TFLOPs | 230 TFLOPs | 1.9x |
| End-to-end training | - | 225 TFLOPs | - |
| GPU utilization | 25-40% | 50-73% | ~2x |

FlashAttention 2 achieves approximately 9x speedup over PyTorch standard attention on long sequences (from Dao, 2023) and 2x speedup over FlashAttention 1.

#FlashAttention 3: Hopper-Specific Optimizations

FlashAttention 2's optimizations were designed for the A100 architecture. When NVIDIA released the H100 with fundamentally new hardware features, new optimization opportunities emerged that FlashAttention 2 couldn't exploit. FlashAttention 3 (Dao et al., 2024) targets these specifically, achieving 75% utilization and 740 TFLOPs on H100.

#New H100 Features

The H100 introduced hardware specifically designed for transformer workloads:

  1. WGMMA (Warpgroup Matrix Multiply-Accumulate): A new matrix multiply instruction that processes larger tiles than A100's equivalent, roughly doubling throughput for the same operation
  2. TMA (Tensor Memory Accelerator): A dedicated hardware unit for memory transfers. Previously, GPU cores had to spend cycles calculating addresses and issuing loads. TMA handles this automatically, freeing cores to compute
  3. Larger shared memory: 228 KB per SM vs 192 KB on A100, allowing larger tiles

#Three Key Techniques

#1. Warp Specialization

Instead of having all warps do the same work, FA3 assigns different roles:

  • Producer warps: Use TMA to asynchronously fetch next K/V blocks
  • Consumer warps: Compute attention on current blocks

This overlaps data movement with computation, hiding memory latency.

#2. Pingpong Scheduling

The expensive non-matmul operations (softmax exp, rescaling) happen between matrix multiplies. FA3 interleaves two warpgroups:

Time →
Warpgroup 0: [GEMM K0] [softmax] [GEMM V0] [softmax] [GEMM K2] ...
Warpgroup 1:          [GEMM K1] [softmax] [GEMM V1] [softmax] ...

While one warpgroup does softmax, the other does GEMM. This hides the throughput gap between these operations—on H100, this gap reaches ~250x (989 TFLOPs for matmul vs ~4 TFLOPs for scalar ops), making overlap even more critical than on A100's 16x gap.

#3. FP8 with Incoherent Processing

H100 supports FP8 (8-bit floating point) with 2x the throughput of FP16. But naive FP8 attention has high quantization error due to outliers in attention scores.

FA3 uses incoherent processing: apply Hadamard transforms (structured orthogonal matrices that redistribute values across dimensions) to Q and K before quantization. This spreads outliers across dimensions rather than concentrating them, reducing quantization error by 2.6x while maintaining speed benefits.
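The reason this is "free" is that any orthogonal transform applied to both Q and K along the head dimension leaves QK^T unchanged. The sketch below uses a random orthogonal matrix from a QR decomposition as a stand-in for the fast Hadamard transform FA3 actually uses; the planted outlier is purely illustrative:

Python
import torch

torch.manual_seed(0)
d = 64
Q = torch.randn(16, d)
K = torch.randn(16, d)
K[:, 3] += 20.0        # plant an outlier feature that would dominate naive quantization

# Random orthogonal matrix as a stand-in for the Hadamard transform
rot, _ = torch.linalg.qr(torch.randn(d, d))

scores_before = Q @ K.T
scores_after = (Q @ rot) @ (K @ rot).T   # (Q R)(K R)^T = Q R R^T K^T = Q K^T

print(torch.allclose(scores_before, scores_after, atol=1e-3))   # True: attention scores are unchanged
print(K.abs().max().item(), (K @ rot).abs().max().item())       # the outlier's magnitude is spread out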

#Performance Results

| Precision | FlashAttention 2 | FlashAttention 3 | Speedup |
|---|---|---|---|
| FP16 | ~370 TFLOPs | 740 TFLOPs | 2.0x |
| FP8 | N/A | ~1,200 TFLOPs | - |

FA3 reaches 75% of H100's theoretical FP16 peak, up from 35% with FA2.

#When to Use FA3

FA3 is specifically optimized for H100 GPUs and requires CUDA 12+. If you're on A100 or older hardware, stick with FA2. The good news: PyTorch's scaled_dot_product_attention automatically selects the best available backend.

#Using FlashAttention in Practice

Since PyTorch 2.0, FlashAttention is built-in:

Python
import torch
import torch.nn.functional as F

# Automatic backend selection (FlashAttention when available)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,      # Optional attention mask
    dropout_p=0.0,       # Dropout probability
    is_causal=True,      # Use causal masking (for autoregressive)
)

PyTorch automatically selects the best backend:

  1. FlashAttention (if supported)
  2. Memory-efficient attention (xFormers-style)
  3. Standard attention (fallback)
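One way to confirm which backend you're getting is to restrict SDPA to FlashAttention only and see whether the call still succeeds. The exact API has shifted across versions; the sketch below assumes a recent PyTorch (roughly 2.3+) with torch.nn.attention.sdpa_kernel:

Python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention backend; this errors out if it can't be used here
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 1024, 64])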

#Direct FlashAttention Library

For more control or features not in PyTorch:

Bash
pip install flash-attn --no-build-isolation
Python
from flash_attn import flash_attn_func

# query, key, value: (batch, seqlen, nheads, headdim)
output = flash_attn_func(
    query, key, value,
    dropout_p=0.0,
    softmax_scale=None,  # Default: 1/sqrt(headdim)
    causal=True,
)

#Integration with Transformers

Most model libraries support FlashAttention automatically:

Python
# HuggingFace Transformers
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # Enable FlashAttention
    torch_dtype=torch.float16,
)

#Combining KV Cache and FlashAttention

These optimizations are complementary:

  • KV Cache: Reduces redundant K/V computation from O(N^2) to O(N) across generation steps
  • FlashAttention: Reduces memory access from O(N^2) to O(N^2 d^2/M) for each attention computation

For maximum performance, use both:

Python
# Example: HuggingFace with both optimizations
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
)

outputs = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True,  # Enable KV caching
)

#Conclusion

Understanding why attention is slow is the first step to fixing it. The key insights:

  1. Attention is memory-bound, not compute-bound. The N^2 attention matrix causes memory bandwidth to become the bottleneck, not compute throughput.

  2. KV caching eliminates redundant computation. By caching keys and values, we reduce per-step key/value computation from O(N) to O(1), giving 5-10x speedups for generation.

  3. FlashAttention is exact attention, not an approximation. It uses tiling and online softmax to avoid materializing the N^2 matrix, reducing memory access by 10-20x.

  4. Hardware-aware optimization matters. FlashAttention 2 and 3 exploit specific GPU features (warp partitioning, TMA, WGMMA) to achieve 50-75% utilization where naive implementations achieve 25-30%.

  5. Use both KV caching and FlashAttention together. They solve different problems and stack multiplicatively.

For most users, the practical advice is simple: use PyTorch 2.0+ with scaled_dot_product_attention, enable KV caching for generation, and the framework handles the rest.

#Your Next Steps

  • Profile your model with torch.profiler to identify whether attention is your bottleneck (a minimal sketch follows this list)
  • Verify FlashAttention is being used: Set TORCH_LOGS="+dynamo" to see backend selection
  • For serving workloads, explore vLLM or TensorRT-LLM which integrate these optimizations with batching and PagedAttention
  • Monitor memory: Use torch.cuda.memory_summary() to see where memory goes
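For the first step, here's a minimal profiling sketch. It assumes the model and input_ids from the HuggingFace example earlier and a CUDA device; treat the kernel-name inspection as a heuristic, not an official API:

Python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=32, use_cache=True)

# Attention kernels (names containing "flash" or "attention") should show up near the top
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))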

Understanding these techniques transforms you from someone who throws hardware at performance problems to someone who can diagnose and fix them at the algorithmic level.


#Bonus: Simple Tiled Attention in Pure PyTorch

For educational purposes, here's a minimal tiled attention implementation that demonstrates the core FlashAttention concept without CUDA:

Python
import torch
import math

def tiled_attention(Q, K, V, block_size=256):
    """
    Simplified tiled attention for educational purposes.
    Not optimized for production - use flash_attn or PyTorch's SDPA instead.

    Q, K, V: (batch, seq_len, d_model)
    Returns: (batch, seq_len, d_model)
    """
    batch, seq_len, d_model = Q.shape
    scale = 1.0 / math.sqrt(d_model)

    # Output accumulator (the running max and normalizer are kept per query block below)
    O = torch.zeros_like(Q)

    # Iterate over query blocks
    for i in range(0, seq_len, block_size):
        q_end = min(i + block_size, seq_len)
        q_block = Q[:, i:q_end, :]  # (batch, block_q, d)

        # Local accumulators for this query block
        o_block = torch.zeros_like(q_block)
        l_block = torch.zeros(batch, q_end - i, 1, device=Q.device)
        m_block = torch.full((batch, q_end - i, 1), float('-inf'), device=Q.device)

        # Iterate over key/value blocks
        for j in range(0, seq_len, block_size):
            k_end = min(j + block_size, seq_len)
            k_block = K[:, j:k_end, :]  # (batch, block_kv, d)
            v_block = V[:, j:k_end, :]

            # Compute attention scores: (batch, block_q, block_kv)
            scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale

            # Online softmax update
            m_new = torch.maximum(m_block, scores.max(dim=-1, keepdim=True).values)

            # Rescale previous accumulator
            exp_diff = torch.exp(m_block - m_new)
            l_block = exp_diff * l_block + torch.exp(scores - m_new).sum(dim=-1, keepdim=True)

            # Update output: rescale old + add new contribution
            o_block = exp_diff * o_block + torch.bmm(torch.exp(scores - m_new), v_block)
            m_block = m_new

        # Final normalization and store
        O[:, i:q_end, :] = o_block / l_block

    return O

# Test correctness
def test_tiled_attention():
    torch.manual_seed(42)
    batch, seq_len, d_model = 2, 1024, 64

    Q = torch.randn(batch, seq_len, d_model)
    K = torch.randn(batch, seq_len, d_model)
    V = torch.randn(batch, seq_len, d_model)

    # Reference: standard attention
    scale = 1.0 / math.sqrt(d_model)
    attn_weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) * scale, dim=-1)
    reference = torch.bmm(attn_weights, V)

    # Our tiled implementation
    tiled = tiled_attention(Q, K, V, block_size=128)

    # Check correctness
    max_diff = (reference - tiled).abs().max().item()
    print(f"Max difference from reference: {max_diff:.2e}")
    assert max_diff < 1e-5, "Tiled attention differs from reference!"
    print("Test passed!")

if __name__ == "__main__":
    test_tiled_attention()

This implementation shows the key ideas (tiling, online softmax) but runs in Python and won't be faster than standard attention. The real speedup comes from the CUDA implementation that keeps tiles in SRAM and fuses all operations into a single kernel.

#Further Reading

Papers:

  • Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. arXiv:2205.14135
  • Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. arXiv:2307.08691
  • Dao et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." NeurIPS 2024. arXiv:2407.08608
  • Shazeer (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150
  • Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. arXiv:2305.13245
  • Kwon et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023. arXiv:2309.06180

Educational Resources:

  • Sebastian Raschka: "Understanding and Coding the KV Cache in LLMs from Scratch"
  • Tri Dao's FlashAttention-3 blog: tridao.me/blog/2024/flash3/
  • Jay Alammar: "The Illustrated Transformer"

Code: