viva_tensor/optim/sparsity

2:4 Structured Sparsity - NVIDIA Tensor Cores

Reference: Mishra et al. (2021) “Accelerating Sparse Deep Neural Networks” https://arxiv.org/abs/2104.08378

Key insight from the paper: 2:4 sparsity achieves 2x theoretical speedup with <1% accuracy loss on ImageNet. In practice, we see ~1.7x due to memory bandwidth limits and kernel launch overhead.

Why 2:4 specifically? NVIDIA’s brilliant constraint: in every contiguous group of 4 elements, at most 2 are non-zero, and the pruning step decides which 2. This is structured enough for hardware acceleration but flexible enough to preserve accuracy. The Sparse Tensor Core can skip 50% of MACs, while the index overhead is just 2 bits per kept value (4 bits per group of 4 elements).
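The constraint can be checked mechanically. Below is a minimal sketch over a flat `List(Float)` instead of this module's `tensor.Tensor`. The helpers `satisfies_2_4` and `metadata_bits` are hypothetical and not part of viva_tensor. It assumes each group stores two 2-bit positions, i.e. 4 bits of metadata per group of 4.

```gleam
import gleam/list

/// True when every contiguous group of 4 values holds at most 2 non-zeros.
/// A trailing group shorter than 4 is checked the same way.
pub fn satisfies_2_4(values: List(Float)) -> Bool {
  values
  |> list.sized_chunk(4)
  |> list.all(fn(group) {
    let nonzeros = list.filter(group, fn(x) { x != 0.0 })
    list.length(nonzeros) <= 2
  })
}

/// Metadata cost in bits: two 2-bit positions (4 bits) per group of 4,
/// counting a trailing partial group as a full group.
pub fn metadata_bits(num_elements: Int) -> Int {
  let groups = { num_elements + 3 } / 4
  groups * 4
}
```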

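Why not arbitrary sparsity? Because general sparse matrix formats (CSR, COO, BCSR) carry per-element (or per-block) indexing overhead that typically makes them slower than dense kernels unless sparsity is very high (roughly 90% or more). At 50% sparsity, dense ops win. 2:4’s fixed structure eliminates the index explosion problem: its metadata is a small, fixed cost per group.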
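Here is a back-of-the-envelope storage comparison at 50% sparsity. It assumes fp32 values, int32 CSR column indices and row pointers, and 4 bits of 2:4 metadata per group of 4. These byte layouts are illustrative assumptions rather than viva_tensor's own accounting, and the function names are hypothetical.

```gleam
// Illustrative byte-size model (assumptions, not viva_tensor's accounting):
// fp32 values = 4 bytes, CSR int32 indices and pointers = 4 bytes,
// 2:4 metadata = 4 bits per group of 4 (two groups share one byte).

/// Dense storage: every element is stored.
pub fn dense_bytes(rows: Int, cols: Int) -> Int {
  rows * cols * 4
}

/// CSR: one value and one column index per non-zero, plus rows + 1 row pointers.
pub fn csr_bytes(rows: Int, nnz: Int) -> Int {
  nnz * 4 + nnz * 4 + { rows + 1 } * 4
}

/// 2:4: two kept values per group of 4, plus 4 bits of positions per group.
pub fn sparse24_bytes(rows: Int, cols: Int) -> Int {
  let elements = rows * cols
  let groups = { elements + 3 } / 4
  let value_bytes = groups * 2 * 4
  let metadata_bytes = { groups + 1 } / 2
  value_bytes + metadata_bytes
}
```

Under these assumptions, a 1024x1024 matrix takes 4,194,304 bytes dense, 4,198,404 bytes as CSR at 50% sparsity (slightly more than dense), and 2,228,224 bytes in 2:4 form.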

Storage format:

Performance reality check (NVIDIA Ampere):

Types

Pruning metrics for analysis. Tracks the quality of the sparsification decision.

pub type PruneMetrics {
  PruneMetrics(
    pruned_count: Int,
    total_count: Int,
    approximation_error: Float,
    kept_magnitude_mean: Float,
    pruned_magnitude_mean: Float,
  )
}

Constructors

  • PruneMetrics(
      pruned_count: Int,
      total_count: Int,
      approximation_error: Float,
      kept_magnitude_mean: Float,
      pruned_magnitude_mean: Float,
    )

    Arguments

    pruned_count

    Elements zeroed out

    total_count

    Total elements

    approximation_error

    L1 approximation error (mean absolute difference)

    kept_magnitude_mean

    Mean magnitude of kept elements

    pruned_magnitude_mean

    Mean magnitude of pruned elements (lower relative to kept_magnitude_mean means better pruning decisions)

Sparse 2:4 block: 4 elements compressed to 2 kept values plus their positions. This is the fundamental unit of 2:4 sparsity.

pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

Constructors

  • Sparse24Block(values: #(Float, Float), positions: #(Int, Int))

    Arguments

    values

    The 2 kept values (survivors of pruning)

    positions

    Positions of the kept values within the group (0-3 each, so 2 bits apiece). In hardware the pair fits in a single 4-bit field.
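Here is one possible way to pack the two positions into a 4-bit code, assuming the first position occupies the low two bits. This layout is a hypothetical illustration. It is not NVIDIA's documented metadata format, nor how this module stores `positions`, which is an `#(Int, Int)` tuple.

```gleam
import gleam/int

/// Pack two 2-bit positions (0-3) into one 4-bit code.
/// Bits 0-1 hold the first position, bits 2-3 the second.
pub fn pack_positions(positions: #(Int, Int)) -> Int {
  let #(first, second) = positions
  int.bitwise_or(
    int.bitwise_and(first, 3),
    int.bitwise_shift_left(int.bitwise_and(second, 3), 2),
  )
}

/// Inverse of pack_positions: recover both positions from the 4-bit code.
pub fn unpack_positions(code: Int) -> #(Int, Int) {
  #(
    int.bitwise_and(code, 3),
    int.bitwise_and(int.bitwise_shift_right(code, 2), 3),
  )
}
```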

Tensor with 2:4 structured sparsity

Sparsity ratio S = (total - nonzero) / total, which is 0.5 for 2:4. Effective FLOPS with 2:4: 2x theoretical, ~1.7x practical.
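The ratio can be checked directly on a list of pruned values. `sparsity_ratio` is a hypothetical helper that operates on `List(Float)`, not on this module's tensor type.

```gleam
import gleam/int
import gleam/list

/// S = (total - nonzero) / total, measured on the actual values.
/// Returns 0.0 for an empty list.
pub fn sparsity_ratio(values: List(Float)) -> Float {
  let total = list.length(values)
  let nonzero = list.length(list.filter(values, fn(x) { x != 0.0 }))
  case total {
    0 -> 0.0
    _ -> int.to_float(total - nonzero) /. int.to_float(total)
  }
}
```

Measured this way on values, a group that was already partly zero, such as `[1.0, 0.0, 0.0, 0.0]`, gives 0.75.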

pub type Sparse24Tensor {
  Sparse24Tensor(
    blocks: List(Sparse24Block),
    shape: List(Int),
    num_elements: Int,
    memory_bytes: Int,
    sparsity_percent: Float,
  )
}

Constructors

  • Sparse24Tensor(
      blocks: List(Sparse24Block),
      shape: List(Int),
      num_elements: Int,
      memory_bytes: Int,
      sparsity_percent: Float,
    )

    Arguments

    blocks

    Sparse blocks covering the entire tensor

    shape

    Original dense shape (for reconstruction)

    num_elements

    Original element count (may not be divisible by 4)

    memory_bytes

    Compressed memory footprint in bytes

    sparsity_percent

    Actual sparsity achieved (always 50% for 2:4)

Values

pub fn benchmark_sparsity() -> Nil
pub fn compute_metrics(
  original: tensor.Tensor,
  sparse: Sparse24Tensor,
) -> PruneMetrics

Compute pruning quality metrics

Key insight: if pruned_magnitude_mean << kept_magnitude_mean, we’re making good pruning decisions. The approximation_error tells us how much information we lost.
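Here is a minimal sketch of the three statistics for flat lists. It assumes an element counts as pruned exactly when its reconstructed value is `0.0`. `PruneMetricsSketch` and `metrics` are hypothetical stand-ins for `PruneMetrics` and `compute_metrics`, which take `tensor.Tensor` values instead.

```gleam
import gleam/float
import gleam/int
import gleam/list

/// Hypothetical subset of PruneMetrics for flat lists.
pub type PruneMetricsSketch {
  PruneMetricsSketch(
    approximation_error: Float,
    kept_magnitude_mean: Float,
    pruned_magnitude_mean: Float,
  )
}

/// Arithmetic mean; 0.0 for an empty list.
fn mean(xs: List(Float)) -> Float {
  case list.length(xs) {
    0 -> 0.0
    n -> list.fold(xs, 0.0, fn(acc, x) { acc +. x }) /. int.to_float(n)
  }
}

/// Assumes both lists have equal length. An element is "pruned" when its
/// reconstructed value is exactly 0.0, and "kept" otherwise.
pub fn metrics(
  original: List(Float),
  reconstructed: List(Float),
) -> PruneMetricsSketch {
  let pairs = list.zip(original, reconstructed)
  let errors = list.map(pairs, fn(p) { float.absolute_value(p.0 -. p.1) })
  let kept =
    pairs
    |> list.filter(fn(p) { p.1 != 0.0 })
    |> list.map(fn(p) { float.absolute_value(p.0) })
  let pruned =
    pairs
    |> list.filter(fn(p) { p.1 == 0.0 })
    |> list.map(fn(p) { float.absolute_value(p.0) })
  PruneMetricsSketch(
    approximation_error: mean(errors),
    kept_magnitude_mean: mean(kept),
    pruned_magnitude_mean: mean(pruned),
  )
}
```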

pub fn decompress(sparse: Sparse24Tensor) -> tensor.Tensor

Reconstruct dense tensor from 2:4 sparse representation

This is O(n) and allocation-heavy. In CUDA, you’d keep it sparse and let the Tensor Core handle the pattern. On CPU, you often need to decompress for compatibility with dense operations.
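Decompressing a single block just scatters the two kept values back into their slots. This sketch mirrors the documented `Sparse24Block` constructor but produces a `List(Float)`. `decompress_block` and `decompress_list` are hypothetical names.

```gleam
import gleam/list

/// Mirrors the documented constructor: two kept values and their positions (0-3).
pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Scatter the two kept values into a group of 4; other slots become 0.0.
pub fn decompress_block(block: Sparse24Block) -> List(Float) {
  let #(v0, v1) = block.values
  let #(p0, p1) = block.positions
  list.map([0, 1, 2, 3], fn(i) {
    case i {
      _ if i == p0 -> v0
      _ if i == p1 -> v1
      _ -> 0.0
    }
  })
}

/// Expand every block, then drop padding beyond the original element count.
pub fn decompress_list(
  blocks: List(Sparse24Block),
  num_elements: Int,
) -> List(Float) {
  blocks
  |> list.flat_map(decompress_block)
  |> list.take(num_elements)
}
```

The final `list.take` reflects the `num_elements` note above: when the original element count is not divisible by 4, the padded tail is trimmed off.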

pub fn main() -> Nil
pub fn prune_24_gradient(
  weights: tensor.Tensor,
  gradients: tensor.Tensor,
) -> Sparse24Tensor

Gradient-weighted pruning for training scenarios

Importance = |weight * gradient|. Intuition: weights that are both large AND changing rapidly (large gradient, so the loss is sensitive to them) matter most. This tends to work better than magnitude alone during fine-tuning.
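Here is a sketch of the importance score and the top-2 selection within one group of 4. `importance` and `top2_positions` are hypothetical helpers over `List(Float)`, not this module's API.

```gleam
import gleam/float
import gleam/int
import gleam/list

/// Importance = |w * g| for each weight/gradient pair.
pub fn importance(weights: List(Float), grads: List(Float)) -> List(Float) {
  list.zip(weights, grads)
  |> list.map(fn(p) { float.absolute_value(p.0 *. p.1) })
}

/// Positions (ascending) of the two highest scores in a group of up to 4.
/// Returns Error(Nil) if the group has fewer than 2 entries.
pub fn top2_positions(scores: List(Float)) -> Result(#(Int, Int), Nil) {
  let ranked =
    list.zip(scores, [0, 1, 2, 3])
    |> list.sort(fn(a, b) { float.compare(b.0, a.0) })
  case ranked {
    [#(_, i), #(_, j), ..] -> Ok(#(int.min(i, j), int.max(i, j)))
    _ -> Error(Nil)
  }
}
```

Take weights `[3.0, 0.5, 1.0, 2.0]` and gradients `[0.01, 4.0, 1.0, 0.1]`. Ranking by magnitude alone keeps positions 0 and 3, while |w * g| keeps positions 1 and 2.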

pub fn prune_24_magnitude(t: tensor.Tensor) -> Sparse24Tensor

Apply 2:4 pruning using magnitude-based selection

Strategy: keep the 2 largest (by absolute value) in each group of 4. This is NVIDIA’s recommended approach and works well empirically.

Heuristic justification: large weights carry more information. Empirical validation (with retraining after pruning): <1% accuracy drop on ImageNet, BERT, GPT-2.
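Here is a sketch of magnitude-based 2:4 pruning over a flat list, reusing the documented `Sparse24Block` shape. `prune_group` and `prune_list` are hypothetical. Zero-padding the trailing group is an assumption about how a length that is not a multiple of 4 might be handled.

```gleam
import gleam/float
import gleam/list

/// Mirrors the documented Sparse24Block constructor.
pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Keep the 2 largest-magnitude entries of a group of up to 4 values.
/// A short trailing group is padded with 0.0 first (assumption).
pub fn prune_group(group: List(Float)) -> Sparse24Block {
  let padded = list.append(group, list.repeat(0.0, 4 - list.length(group)))
  let ranked =
    list.zip(padded, [0, 1, 2, 3])
    |> list.sort(fn(a, b) {
      float.compare(float.absolute_value(b.0), float.absolute_value(a.0))
    })
  case ranked {
    [first, second, ..] -> {
      // Store the survivors in position order so decompression is simple.
      let #(lo, hi) = case first.1 < second.1 {
        True -> #(first, second)
        False -> #(second, first)
      }
      Sparse24Block(values: #(lo.0, hi.0), positions: #(lo.1, hi.1))
    }
    // Unreachable: padded always has at least 4 entries.
    _ -> Sparse24Block(values: #(0.0, 0.0), positions: #(0, 1))
  }
}

/// Prune a whole flat list, one group of 4 at a time.
pub fn prune_list(values: List(Float)) -> List(Sparse24Block) {
  values
  |> list.sized_chunk(4)
  |> list.map(prune_group)
}
```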

pub fn sparse_matmul(
  sparse_a: Sparse24Tensor,
  dense_b: tensor.Tensor,
) -> #(tensor.Tensor, Float)

Sparse matrix multiplication (simulated)

On Tensor Cores, the hardware skips 50% of the multiplications. Here we decompress and multiply densely for correctness. Returns: #(result, theoretical_speedup).
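To show where the savings come from, this sketch computes one row of a 2:4-sparse product directly from the blocks. It performs 2 multiply-adds per group of 4 instead of 4. `sparse_dot` is a hypothetical helper, and unlike the simulated `sparse_matmul` it does not decompress.

```gleam
import gleam/list
import gleam/result

/// Mirrors the documented Sparse24Block constructor.
pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Element `i` of a short list, or 0.0 if out of range.
fn at(xs: List(Float), i: Int) -> Float {
  xs
  |> list.drop(i)
  |> list.first
  |> result.unwrap(0.0)
}

/// Dot product of one 2:4-sparse row with a dense column.
/// Each block costs 2 multiply-adds; returns #(sum, macs_performed).
pub fn sparse_dot(
  row: List(Sparse24Block),
  column: List(Float),
) -> #(Float, Int) {
  list.zip(row, list.sized_chunk(column, 4))
  |> list.fold(#(0.0, 0), fn(acc, pair) {
    let #(sum, macs) = acc
    let #(block, chunk) = pair
    let #(v0, v1) = block.values
    let #(p0, p1) = block.positions
    #(sum +. v0 *. at(chunk, p0) +. v1 *. at(chunk, p1), macs + 2)
  })
}
```

A dense dot product over the same 8 inputs would need 8 MACs versus 4 here. That ratio is the 2x theoretical speedup, and real kernels fall short of it for the bandwidth and launch-overhead reasons mentioned at the top of this module.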
