viva_tensor/nn/flash_attention

Flash Attention - IO-Aware Exact Attention in O(n) Memory

“The key insight is that memory bandwidth, not FLOPs, is the bottleneck.” — Tri Dao, channeling every GPU programmer’s frustration

The Problem: Standard attention computes scores = Q @ K^T, which materializes an n x n matrix. For n=8192 that is 67M elements, or 256MB per head at 4 bytes per element (float32). With 32 heads: 8GB. Ouch.
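The arithmetic above can be checked with a few lines of Gleam. This is a minimal sketch: score_matrix_bytes is a hypothetical helper, not part of this module, and it assumes 4-byte (float32) scores.

```gleam
import gleam/int
import gleam/io

/// Bytes needed to materialize an n x n score matrix for `heads` heads.
/// `bytes_per_element` is 4 for float32 scores (an assumption here).
pub fn score_matrix_bytes(n: Int, heads: Int, bytes_per_element: Int) -> Int {
  n * n * bytes_per_element * heads
}

pub fn main() -> Nil {
  let mib = 1024 * 1024
  // 8192 * 8192 = 67_108_864 scores; at 4 bytes each that is 256 MiB.
  let per_head = score_matrix_bytes(8192, 1, 4)
  io.println("per head: " <> int.to_string(per_head / mib) <> " MiB")
  // 32 heads: 8192 MiB, i.e. 8 GiB.
  let all_heads = score_matrix_bytes(8192, 32, 4)
  io.println("32 heads: " <> int.to_string(all_heads / mib) <> " MiB")
}
```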

The Solution: Process in TILES. Never materialize the full n x n matrix. Online softmax updates running statistics incrementally as each tile is processed. Result: O(n) memory and an EXACT result (not an approximation!). With fused GPU kernels it is also typically 2-4x faster.

Why it works (IO-awareness): GPU has fast SRAM (on-chip) and slow HBM (off-chip). Standard attention: write n^2 elements to HBM, read them back. Slow. Flash attention: keep working set in SRAM, only read/write O(n) to HBM. Memory bandwidth wins over raw FLOPs. Every. Single. Time.

This implementation is a pure Gleam demonstration. For production, you’d want CUDA kernels that fuse the operations. But the math is the same.

Types

Flash Attention configuration.

Block sizes determine the tile dimensions. Larger blocks = fewer iterations but more SRAM usage. The sweet spot depends on your GPU’s SRAM size. For A100: block_q=128, block_kv=128 works well. For consumer GPUs: 64x64 is safer.

pub type FlashConfig {
  FlashConfig(
    block_q: Int,
    block_kv: Int,
    scale: Float,
    causal: Bool,
  )
}

Constructors

  • FlashConfig(
      block_q: Int,
      block_kv: Int,
      scale: Float,
      causal: Bool,
    )

    Arguments

    block_q

    Block size for Q dimension (rows of attention matrix)

    block_kv

    Block size for KV dimension (columns of attention matrix)

    scale

    Scaling factor: 1/sqrt(d_k). Keeps the dot products from growing with d_k and saturating the softmax.

    causal

    Causal masking for autoregressive models (GPT-style)

Result of Flash Attention, with memory statistics.

pub type FlashResult {
  FlashResult(
    output: tensor.Tensor,
    memory_bytes: Int,
    memory_saved_percent: Float,
  )
}

Constructors

  • FlashResult(
      output: tensor.Tensor,
      memory_bytes: Int,
      memory_saved_percent: Float,
    )

    Arguments

    output

    The attention output tensor

    memory_bytes

    Peak memory usage in bytes (just the tile, not the full matrix)

    memory_saved_percent

    Percentage of memory saved vs naive attention
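One plausible way to arrive at such a percentage is to compare a single score tile against the full matrix. The sketch below is an assumption about the formula, not necessarily what this module computes, and saved_percent is a hypothetical name.

```gleam
import gleam/int

/// Percentage saved when only a block_q x block_kv tile of scores is live
/// instead of the full n x n matrix. The element size cancels out.
/// Illustrative formula only; the library may count memory differently.
pub fn saved_percent(n: Int, block_q: Int, block_kv: Int) -> Float {
  let naive = int.to_float(n * n)
  let tile = int.to_float(block_q * block_kv)
  100.0 *. { 1.0 -. tile /. naive }
}
```

For n=8192 with 64x64 tiles, the tile is 1/16384 of the full matrix, so the saving under this formula is a little over 99.99%.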

Running statistics for online softmax computation.

The magic of Flash Attention: we don’t need all the scores to compute softmax. We track max (for numerical stability) and sum_exp (for normalization), updating them as we process each KV block.

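Math: softmax(x)_i = exp(x_i - max(x)) / sum_j(exp(x_j - max(x)))

We can compute this incrementally: when a new block raises the running max from m_old to m_new, multiply sum_exp and the accumulated output by exp(m_old - m_new) before adding the block's terms.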

pub type OnlineStats {
  OnlineStats(
    max_val: Float,
    sum_exp: Float,
    output: List(Float),
  )
}

Constructors

  • OnlineStats(max_val: Float, sum_exp: Float, output: List(Float))

    Arguments

    max_val

    Running maximum (for numerical stability in exp)

    sum_exp

    Running sum of exp(score - max) for normalization

    output

    Accumulated output (will be normalized at the end)
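The update for one query row can be sketched as follows. This is an illustration under stated assumptions, not the module's actual code: Stats is a local stand-in that mirrors the fields of OnlineStats, the names new, update, finish and example are hypothetical, and float.exponential requires a recent gleam_stdlib. A large negative number stands in for negative infinity.

```gleam
import gleam/float
import gleam/list

/// Local stand-in mirroring the fields of OnlineStats (illustrative only).
pub type Stats {
  Stats(max_val: Float, sum_exp: Float, output: List(Float))
}

/// Empty statistics for value vectors of length `dim`.
pub fn new(dim: Int) -> Stats {
  Stats(max_val: -1.0e30, sum_exp: 0.0, output: list.repeat(0.0, dim))
}

/// Fold one KV block into the running statistics for a single query row.
/// `scores` holds the scaled q.k values for the block, `values` the
/// matching V rows.
pub fn update(
  stats: Stats,
  scores: List(Float),
  values: List(List(Float)),
) -> Stats {
  let block_max = list.fold(scores, stats.max_val, float.max)
  // Rescale everything accumulated under the old maximum.
  let correction = float.exponential(stats.max_val -. block_max)
  let init = #(
    stats.sum_exp *. correction,
    list.map(stats.output, fn(o) { o *. correction }),
  )
  let #(sum_exp, output) =
    list.fold(list.zip(scores, values), init, fn(acc, pair) {
      let #(sum, out) = acc
      let #(score, value) = pair
      let weight = float.exponential(score -. block_max)
      #(sum +. weight, list.map2(out, value, fn(o, v) { o +. weight *. v }))
    })
  Stats(max_val: block_max, sum_exp: sum_exp, output: output)
}

/// Final normalization: divide the accumulated output by the running sum.
pub fn finish(stats: Stats) -> List(Float) {
  list.map(stats.output, fn(o) { o /. stats.sum_exp })
}

/// Four equal scores split over two blocks: the result is the plain
/// average of the values 1.0, 2.0, 3.0 and 4.0, i.e. [2.5].
pub fn example() -> List(Float) {
  new(1)
  |> update([0.0, 0.0], [[1.0], [2.0]])
  |> update([0.0, 0.0], [[3.0], [4.0]])
  |> finish
}
```

Only max_val, sum_exp and one output row are kept between blocks, which is why the scores of earlier blocks never need to be stored.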

Values

pub fn benchmark_flash_attention() -> Nil

pub fn causal_config(head_dim: Int) -> FlashConfig

Causal configuration for autoregressive models.

In causal attention, position i can only attend to positions j <= i. This is how GPT, LLaMA, and friends generate text token by token.
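The rule j <= i can be applied to a row of scores before the softmax. A minimal sketch, assuming masked scores are replaced by a very negative number so that their exponential underflows to 0.0; mask_row is a hypothetical helper and this module may implement masking differently.

```gleam
/// Mask the scores of query position `i` against a run of key positions
/// starting at `kv_start`. Keys with j > i get a very negative score.
pub fn mask_row(i: Int, kv_start: Int, scores: List(Float)) -> List(Float) {
  case scores {
    [] -> []
    [score, ..rest] -> {
      let masked = case kv_start <= i {
        // Key position is at or before the query: keep the score.
        True -> score
        // Key position is in the future: mask it out.
        False -> -1.0e30
      }
      [masked, ..mask_row(i, kv_start + 1, rest)]
    }
  }
}
```

For example, mask_row(1, 0, [0.5, 0.6, 0.7]) keeps the first two scores and masks the third, because key position 2 lies after query position 1.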

pub fn default_config(head_dim: Int) -> FlashConfig

Default configuration optimized for common use cases.

Block sizes of 64 work on most GPUs. Scale is 1/sqrt(head_dim), following the original Transformer paper.
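The scale itself is a one-liner. The sketch below shows how 1/sqrt(head_dim) could be computed with the standard library; attention_scale is a hypothetical name, and the fallback of 1.0 for a non-positive head_dim is an assumption made here to keep the function total.

```gleam
import gleam/float
import gleam/int

/// 1/sqrt(head_dim), falling back to 1.0 when head_dim is not positive.
pub fn attention_scale(head_dim: Int) -> Float {
  case float.square_root(int.to_float(head_dim)) {
    Ok(root) if root >. 0.0 -> 1.0 /. root
    _ -> 1.0
  }
}
```

With head_dim = 64 the square root is 8.0, so the scale is 0.125.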

pub fn flash_attention(
  q: tensor.Tensor,
  k: tensor.Tensor,
  v: tensor.Tensor,
  config: FlashConfig,
) -> FlashResult

Flash Attention: exact attention with O(n) memory.

This is the algorithm that helped make 100K+ context windows practical in LLMs. It’s not an approximation: it computes the same result as naive attention, up to floating-point rounding. The magic is in the order of computation and the online softmax trick.
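A call might look like the sketch below. It uses only the functions and fields documented on this page; attend_and_report is a hypothetical wrapper, and q, k and v are assumed to be tensor.Tensor values built elsewhere, since tensor construction is not covered here.

```gleam
import gleam/float
import gleam/int
import gleam/io
import viva_tensor/nn/flash_attention

/// Run causal Flash Attention and print the reported memory statistics.
/// q, k and v are tensor.Tensor values created by the caller.
pub fn attend_and_report(q, k, v, head_dim: Int) {
  let config = flash_attention.causal_config(head_dim)
  let result = flash_attention.flash_attention(q, k, v, config)
  io.println("peak bytes: " <> int.to_string(result.memory_bytes))
  io.println("saved: " <> float.to_string(result.memory_saved_percent) <> "%")
  result.output
}
```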

pub fn main() -> Nil

pub fn naive_attention(
  q: tensor.Tensor,
  k: tensor.Tensor,
  v: tensor.Tensor,
  scale: Float,
) -> #(tensor.Tensor, Int)

Standard attention with O(n^2) memory. DON’T use for long sequences.

This allocates the full attention matrix. For n=8192, that’s 256MB per head. Included only to show what Flash Attention saves you from.
