viva_tensor/nn/flash_attention
Flash Attention - IO-Aware Exact Attention in O(n) Memory
“The key insight is that memory bandwidth, not FLOPs, is the bottleneck.” — Tri Dao, channeling every GPU programmer’s frustration
References:
- Dao et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” https://arxiv.org/abs/2205.14135
- Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” https://arxiv.org/abs/2307.08691
- Rabe & Staats (2021). “Self-attention Does Not Need O(n^2) Memory.” https://arxiv.org/abs/2112.05682 (the theoretical foundation that made Flash Attention possible)
The Problem: standard attention computes scores = Q @ K^T, which creates an n x n matrix. For n=8192 that is 67M elements = 256MB per head in float32. 32 heads = 8GB. Ouch.
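A quick check of that arithmetic as a minimal Gleam sketch. The function names are hypothetical (not part of viva_tensor), and 4-byte float32 elements are assumed, which is what the 256MB figure implies:

```gleam
/// Bytes needed to materialize one n x n attention score matrix.
/// `bytes_per_element` is 4 for float32 (an assumption of this sketch).
pub fn score_matrix_bytes(seq_len: Int, bytes_per_element: Int) -> Int {
  seq_len * seq_len * bytes_per_element
}

/// Total for all heads of one layer: one n x n matrix per head.
pub fn all_heads_bytes(seq_len: Int, heads: Int, bytes_per_element: Int) -> Int {
  heads * score_matrix_bytes(seq_len, bytes_per_element)
}
```

`score_matrix_bytes(8192, 4)` is 268,435,456 bytes (256 MiB), and `all_heads_bytes(8192, 32, 4)` is 8,589,934,592 bytes (8 GiB).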
The Solution: process in TILES and never materialize the full n x n matrix. Online softmax updates running statistics (the max and the sum of exponentials) incrementally, one tile at a time. Result: O(n) memory, 2-4x faster with fused GPU kernels, and EXACT (not an approximation!).
Why it works (IO-awareness): a GPU has fast on-chip SRAM and slow off-chip HBM. Standard attention writes all n^2 scores to HBM and reads them back. Slow. Flash Attention keeps each tile's working set in SRAM and never writes the n x n matrix to HBM: it reads Q, K, and V in tiles and writes back only the n x d output (plus small per-row statistics). Memory bandwidth wins over raw FLOPs. Every. Single. Time.
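Dao et al. (2022) make the IO argument precise (Theorem 2 of the FlashAttention paper). With sequence length n, head dimension d, and SRAM size M satisfying d <= M <= nd, the number of HBM accesses is:

```latex
\underbrace{\Theta\left(nd + n^2\right)}_{\text{standard attention}}
\qquad\text{vs.}\qquad
\underbrace{\Theta\left(\frac{n^2 d^2}{M}\right)}_{\text{Flash Attention}}
```

For typical head dimensions (64 to 128), d^2 is many times smaller than M, so Flash Attention touches HBM far less often while doing the same O(n^2 d) order of FLOPs.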
This implementation is a pure Gleam demonstration. For production, you’d want CUDA kernels that fuse the operations. But the math is the same.
Types
Flash Attention configuration.
Block sizes determine the tile dimensions. Larger blocks = fewer iterations but more SRAM usage. The sweet spot depends on your GPU’s SRAM size. For A100: block_q=128, block_kv=128 works well. For consumer GPUs: 64x64 is safer.
pub type FlashConfig {
FlashConfig(
block_q: Int,
block_kv: Int,
scale: Float,
causal: Bool,
)
}
Constructors
- FlashConfig(block_q: Int, block_kv: Int, scale: Float, causal: Bool)
  Arguments
  - block_q: Block size for the Q dimension (rows of the attention matrix).
  - block_kv: Block size for the K/V dimension (columns of the attention matrix).
  - scale: Scaling factor, 1/sqrt(d_k). Keeps the dot products from growing with d_k, which would otherwise push softmax into saturation.
  - causal: Enables causal masking for autoregressive (GPT-style) models.
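To see how block sizes trade iterations for memory, here is a sketch that counts tiles and the bytes of one score tile. It uses a local copy of FlashConfig so it compiles on its own; `tile_bytes` and `tile_count` are hypothetical helpers, and float32 scores are assumed. Only the score tile is counted, not the Q, K, V, and output tiles that also occupy SRAM in a real kernel.

```gleam
/// Local copy of FlashConfig so this sketch compiles on its own.
pub type FlashConfig {
  FlashConfig(block_q: Int, block_kv: Int, scale: Float, causal: Bool)
}

/// Bytes for one block_q x block_kv score tile, assuming 4-byte float32 scores.
pub fn tile_bytes(config: FlashConfig) -> Int {
  config.block_q * config.block_kv * 4
}

/// Ceiling division for positive integers.
fn ceil_div(a: Int, b: Int) -> Int {
  { a + b - 1 } / b
}

/// Number of (Q block, KV block) tiles visited for sequence length `seq_len`,
/// ignoring any tiles a causal implementation could skip.
pub fn tile_count(config: FlashConfig, seq_len: Int) -> Int {
  ceil_div(seq_len, config.block_q) * ceil_div(seq_len, config.block_kv)
}
```

For n=8192, a 128x128 config visits 4096 tiles of 64 KiB each, while 64x64 visits 16384 tiles of 16 KiB: four times the iterations for a quarter of the tile memory.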
Result of Flash Attention, with memory statistics.
pub type FlashResult {
FlashResult(
output: tensor.Tensor,
memory_bytes: Int,
memory_saved_percent: Float,
)
}
Constructors
- FlashResult(output: tensor.Tensor, memory_bytes: Int, memory_saved_percent: Float)
  Arguments
  - output: The attention output tensor.
  - memory_bytes: Peak memory usage in bytes (just the tile, not the full matrix).
  - memory_saved_percent: Percentage of memory saved vs. naive attention.
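The docs don't spell out how memory_saved_percent is computed. One plausible definition compares the tile bytes against the full n x n matrix; it is sketched below as a hypothetical helper (`saved_percent` is not part of the library):

```gleam
import gleam/int

/// Hypothetical: percentage saved by holding `tile_bytes` instead of the
/// `full_bytes` of the whole n x n score matrix.
pub fn saved_percent(tile_bytes: Int, full_bytes: Int) -> Float {
  case full_bytes {
    // Nothing to save against an empty matrix.
    0 -> 0.0
    _ -> 100.0 *. { 1.0 -. int.to_float(tile_bytes) /. int.to_float(full_bytes) }
  }
}
```

With a 64x64 float32 tile (16,384 bytes) against the 256MB matrix for n=8192, this gives about 99.994%.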
Running statistics for online softmax computation.
The magic of Flash Attention: we don’t need all the scores at once to compute softmax. We track max_val (for numerical stability) and sum_exp (for normalization), updating them as we process each KV block.
Math: softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). We can compute this incrementally: whenever the running max increases, multiply sum_exp and the accumulated output by exp(old_max - new_max) before adding the new block’s terms.
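Written out, one KV-block update of OnlineStats looks like this, with m = max_val, l = sum_exp, o = output, s_j the block's scores, and v_j the matching value rows:

```latex
m' = \max\Big(m,\ \max_j s_j\Big)
\qquad
l' = e^{m - m'}\, l + \sum_j e^{s_j - m'}
\qquad
o' = e^{m - m'}\, o + \sum_j e^{s_j - m'}\, v_j
\qquad
\text{output} = \frac{o_{\text{final}}}{l_{\text{final}}}
```

Because every term is kept relative to the current max, the final o / l equals the softmax-weighted sum of the v_j that naive attention computes.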
pub type OnlineStats {
OnlineStats(
max_val: Float,
sum_exp: Float,
output: List(Float),
)
}
Constructors
- OnlineStats(max_val: Float, sum_exp: Float, output: List(Float))
  Arguments
  - max_val: Running maximum (for numerical stability in exp).
  - sum_exp: Running sum of exp(score - max_val), used for normalization.
  - output: Accumulated, unnormalized output (divided by sum_exp at the end).
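A minimal sketch of the OnlineStats update, using a local copy of the type so it compiles on its own. `init`, `update`, and `finish` are hypothetical names, scores are assumed to be already scaled, a recent gleam_stdlib with `float.exponential` is assumed, and the running max starts at a real score, which avoids needing a negative-infinity sentinel.

```gleam
import gleam/float
import gleam/list

/// Local copy of OnlineStats so this sketch compiles on its own.
pub type OnlineStats {
  OnlineStats(max_val: Float, sum_exp: Float, output: List(Float))
}

/// Start state. `first_score` must be a real score so the running max
/// never exceeds the true max; `dim` is the value-vector length.
pub fn init(first_score: Float, dim: Int) -> OnlineStats {
  OnlineStats(max_val: first_score, sum_exp: 0.0, output: list.repeat(0.0, dim))
}

/// Fold one KV block (one query row's scores plus the matching V rows).
pub fn update(
  stats: OnlineStats,
  scores: List(Float),
  values: List(List(Float)),
) -> OnlineStats {
  let new_max = list.fold(scores, stats.max_val, float.max)
  // Everything accumulated so far is relative to the old max: rescale it.
  let correction = float.exponential(stats.max_val -. new_max)
  let weights = list.map(scores, fn(s) { float.exponential(s -. new_max) })
  let output =
    list.fold(
      list.zip(weights, values),
      list.map(stats.output, fn(o) { o *. correction }),
      fn(sum_vec, pair) {
        let #(w, v_row) = pair
        list.map2(sum_vec, v_row, fn(a, x) { a +. w *. x })
      },
    )
  OnlineStats(
    max_val: new_max,
    sum_exp: stats.sum_exp *. correction +. float.sum(weights),
    output: output,
  )
}

/// Final normalization: divide the accumulated output by sum_exp.
pub fn finish(stats: OnlineStats) -> List(Float) {
  list.map(stats.output, fn(o) { o /. stats.sum_exp })
}
```

Feeding the scores as one block or as several gives the same result up to rounding; that invariance is what lets the algorithm stream over KV blocks.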
Values
pub fn benchmark_flash_attention() -> Nil
pub fn causal_config(head_dim: Int) -> FlashConfig
Causal configuration for autoregressive models.
In causal attention, position i can only attend to positions j <= i. This is how GPT, LLaMA, and friends generate text token by token.
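Causal masking also lets a tiled implementation skip work: when every key in a KV tile comes after every query in a Q tile, the whole tile is masked. A minimal sketch of that test, assuming 0-based block indices and equal-length query and key sequences (`tile_fully_masked` and `allowed` are hypothetical names):

```gleam
/// True when every key in KV tile `kv_block` comes after every query in
/// Q tile `q_block`, so causal masking hides the whole tile (0-based).
pub fn tile_fully_masked(
  q_block: Int,
  kv_block: Int,
  block_q: Int,
  block_kv: Int,
) -> Bool {
  let last_query = { q_block + 1 } * block_q - 1
  let first_key = kv_block * block_kv
  first_key > last_query
}

/// Per-element causal rule: query position i may attend to key j only if j <= i.
pub fn allowed(i: Int, j: Int) -> Bool {
  j <= i
}
```

With equal square tiles, roughly half of all tiles are fully masked for long sequences; tiles on the diagonal are only partially masked and still need the per-element check.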
pub fn default_config(head_dim: Int) -> FlashConfig
Default configuration optimized for common use cases.
Block sizes of 64 work on most GPUs. Scale is 1/sqrt(head_dim), following the original Transformer paper.
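The two constructors can be approximated as follows. This is a hypothetical re-creation based only on the descriptions above (64x64 blocks, scale 1/sqrt(head_dim)); it uses a local copy of FlashConfig and assumes default_config leaves causal off, which the docs imply but don't state.

```gleam
import gleam/float
import gleam/int

/// Local copy of FlashConfig so this sketch compiles on its own.
pub type FlashConfig {
  FlashConfig(block_q: Int, block_kv: Int, scale: Float, causal: Bool)
}

/// Hypothetical default: 64x64 blocks, scale = 1/sqrt(head_dim).
/// Assumes head_dim > 0 (square_root only fails for negative input).
pub fn sketch_default_config(head_dim: Int) -> FlashConfig {
  let assert Ok(root) = float.square_root(int.to_float(head_dim))
  FlashConfig(block_q: 64, block_kv: 64, scale: 1.0 /. root, causal: False)
}

/// Hypothetical causal variant: same blocks and scale, masking switched on.
pub fn sketch_causal_config(head_dim: Int) -> FlashConfig {
  let base = sketch_default_config(head_dim)
  FlashConfig(..base, causal: True)
}
```

For head_dim = 64 the scale comes out as 1/8 = 0.125.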
pub fn flash_attention(
q: tensor.Tensor,
k: tensor.Tensor,
v: tensor.Tensor,
config: FlashConfig,
) -> FlashResult
Flash Attention: exact attention with O(n) memory.
This is one of the key techniques behind 100K+ context windows in LLMs. It’s not an approximation: it computes the same result as naive attention, up to floating-point rounding. The magic is in the order of computation and the online softmax trick.
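A minimal, sequential sketch of the algorithm over plain Lists of rows, not the library's tensor-based implementation. Each query row walks K and V one tile of `block_kv` rows at a time, keeping only a running max, a running sum of exponentials, and an unnormalized output (the role OnlineStats plays). Q tiling is omitted because rows are independent in a sequential sketch, the causal check assumes self-attention (query i and key j index the same sequence), and a recent gleam_stdlib with `float.exponential` is assumed. `tiled_attention`, `absorb`, and `RowState` are hypothetical names.

```gleam
import gleam/float
import gleam/list

/// Per-row running state: max score, sum of exps, unnormalized output.
type RowState {
  RowState(max_val: Float, sum_exp: Float, acc: List(Float))
}

fn dot(a: List(Float), b: List(Float)) -> Float {
  float.sum(list.map2(a, b, fn(x, y) { x *. y }))
}

/// Fold the visible keys of one KV tile into a query row's state.
fn absorb(
  state: RowState,
  tile: List(#(Int, List(Float), List(Float))),
  q_row: List(Float),
  scale: Float,
) -> RowState {
  let scores =
    list.map(tile, fn(t) {
      let #(_, k_row, _) = t
      dot(q_row, k_row) *. scale
    })
  let new_max = list.fold(scores, state.max_val, float.max)
  let correction = float.exponential(state.max_val -. new_max)
  let weights = list.map(scores, fn(s) { float.exponential(s -. new_max) })
  let acc =
    list.fold(
      list.zip(weights, tile),
      list.map(state.acc, fn(o) { o *. correction }),
      fn(sum_vec, pair) {
        let #(w, #(_, _, v_row)) = pair
        list.map2(sum_vec, v_row, fn(a, x) { a +. w *. x })
      },
    )
  RowState(new_max, state.sum_exp *. correction +. float.sum(weights), acc)
}

/// Exact attention, visiting K/V in tiles of `block_kv` rows.
/// Assumes q, k, v are non-empty and k, v have the same number of rows.
pub fn tiled_attention(
  q: List(List(Float)),
  k: List(List(Float)),
  v: List(List(Float)),
  block_kv: Int,
  scale: Float,
  causal: Bool,
) -> List(List(Float)) {
  // Tag each K/V row with its position j for the causal test.
  let indexed =
    list.index_map(list.zip(k, v), fn(kv, j) {
      let #(k_row, v_row) = kv
      #(j, k_row, v_row)
    })
  let tiles = list.sized_chunk(indexed, block_kv)
  let assert [#(_, k0, v0), ..] = indexed
  list.index_map(q, fn(q_row, i) {
    // Key 0 is never masked, so its score is a safe starting maximum.
    let start =
      RowState(dot(q_row, k0) *. scale, 0.0, list.map(v0, fn(_) { 0.0 }))
    let done =
      list.fold(tiles, start, fn(state, tile) {
        let visible =
          list.filter(tile, fn(t) {
            let #(j, _, _) = t
            !causal || j <= i
          })
        case visible {
          // Fully masked tile: skipped without computing any scores.
          [] -> state
          _ -> absorb(state, visible, q_row, scale)
        }
      })
    list.map(done.acc, fn(o) { o /. done.sum_exp })
  })
}
```

Changing `block_kv` changes only the order of accumulation, so results match across tile sizes up to floating-point rounding.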
pub fn naive_attention(
q: tensor.Tensor,
k: tensor.Tensor,
v: tensor.Tensor,
scale: Float,
) -> #(tensor.Tensor, Int)
Standard attention with O(n^2) memory. DON’T use for long sequences.
This allocates the full attention matrix. For n=8192, that’s 256MB per head in float32. Included only to show what Flash Attention saves you from.
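For comparison, a reference version that builds every row of the score matrix before the softmax. It is a sketch over plain Lists (`reference_attention` is a hypothetical name), it assumes a recent gleam_stdlib with `float.exponential`, and it assumes the Int in naive_attention's return tuple is the byte count of the float32 score matrix, which the docs don't state.

```gleam
import gleam/float
import gleam/list

fn dot(a: List(Float), b: List(Float)) -> Float {
  float.sum(list.map2(a, b, fn(x, y) { x *. y }))
}

/// Row-wise softmax(Q K^T * scale) V with the full score matrix in memory.
/// Returns the output and the bytes the n x m scores would take as float32.
pub fn reference_attention(
  q: List(List(Float)),
  k: List(List(Float)),
  v: List(List(Float)),
  scale: Float,
) -> #(List(List(Float)), Int) {
  // The full score matrix: exactly what Flash Attention avoids building.
  let scores =
    list.map(q, fn(q_row) {
      list.map(k, fn(k_row) { dot(q_row, k_row) *. scale })
    })
  let assert [v0, ..] = v
  let zeros = list.map(v0, fn(_) { 0.0 })
  let output =
    list.map(scores, fn(row) {
      let assert [first, ..] = row
      // Subtract the row max before exp for numerical stability.
      let m = list.fold(row, first, float.max)
      let weights = list.map(row, fn(s) { float.exponential(s -. m) })
      let total = float.sum(weights)
      list.fold(list.zip(weights, v), zeros, fn(sum_vec, pair) {
        let #(w, v_row) = pair
        list.map2(sum_vec, v_row, fn(a, x) { a +. w *. x })
      })
      |> list.map(fn(x) { x /. total })
    })
  #(output, list.length(q) * list.length(k) * 4)
}
```

Comparing its output with any tiled implementation on small inputs is an easy way to check the exactness claim.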