viva_tensor/optim/sparsity
2:4 Structured Sparsity - NVIDIA Tensor Cores
Reference: Mishra et al. (2021) “Accelerating Sparse Deep Neural Networks” https://arxiv.org/abs/2104.08378
Key insight from the paper: 2:4 sparsity achieves 2x theoretical speedup with <1% accuracy loss on ImageNet. In practice, we see ~1.7x due to memory bandwidth limits and kernel launch overhead.
Why 2:4 specifically? NVIDIA’s brilliant constraint: in every group of 4 consecutive elements, at least 2 must be zero, and the zeros may sit in any 2 of the 4 positions. This is structured enough for hardware acceleration but flexible enough to preserve accuracy. The Sparse Tensor Core can skip 50% of MACs while the index overhead is just 2 bits per kept element, i.e. 4 bits per group of 4.
Why not arbitrary sparsity? Because sparse matrix formats (CSR, COO, BCSR) have indexing overhead that kills performance for sparsity < 90%. At 50% sparsity, dense ops win. 2:4’s fixed structure eliminates the index explosion problem.
Storage format:
- 2 values (FP16: 32 bits total)
- 2 position indices of 2 bits each (4 bits, padded to 8 bits in practice)
- Total: ~40 bits for 4 elements vs 64 bits dense = 1.6x compression
- Real memory savings: ~1.78x when the 4 metadata bits are packed tightly across groups (36 bits vs 64 bits)
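The arithmetic behind these ratios can be checked with a small sketch. The function names below are hypothetical and not part of viva_tensor; the only inputs are the figures stated above (FP16 = 16 bits, 2 kept values per group, 4 or 8 bits of position metadata per group).

```gleam
import gleam/int

/// Bits used by `groups` groups of 4 FP16 values stored densely.
pub fn dense_bits(groups: Int) -> Int {
  groups * 4 * 16
}

/// Bits used by the same groups in 2:4 form: 2 FP16 values plus
/// `metadata_bits` of position metadata per group
/// (4 if packed tightly, 8 if padded to a byte).
pub fn sparse_bits(groups: Int, metadata_bits: Int) -> Int {
  groups * { 2 * 16 + metadata_bits }
}

/// Dense size divided by compressed size.
pub fn compression_ratio(groups: Int, metadata_bits: Int) -> Float {
  int.to_float(dense_bits(groups))
  /. int.to_float(sparse_bits(groups, metadata_bits))
}
```

With byte-padded metadata, `compression_ratio(1, 8)` is 64/40 = 1.6; with tightly packed metadata, `compression_ratio(1, 4)` is 64/36, roughly 1.78.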
Performance reality check (NVIDIA Ampere):
- Dense FP16 matmul: ~150 TFLOPS
- Sparse 2:4 FP16 matmul: ~300 TFLOPS (theoretical 2x)
- Actual observed: ~250-280 TFLOPS (1.7-1.9x)
- Bottleneck: memory bandwidth, not compute
Types
Pruning metrics for analysis.
Tracks the quality of the sparsification decision.
pub type PruneMetrics {
  PruneMetrics(
    pruned_count: Int,
    total_count: Int,
    approximation_error: Float,
    kept_magnitude_mean: Float,
    pruned_magnitude_mean: Float,
  )
}
Constructors
- PruneMetrics(pruned_count: Int, total_count: Int, approximation_error: Float, kept_magnitude_mean: Float, pruned_magnitude_mean: Float)
Arguments
- pruned_count: Elements zeroed out
- total_count: Total elements
- approximation_error: L1 approximation error (mean absolute difference)
- kept_magnitude_mean: Mean magnitude of kept elements
- pruned_magnitude_mean: Mean magnitude of pruned elements (lower = good pruning decisions)
Sparse 2:4 block: 4 elements compressed to 2 non-zeros + position mask.
This is the fundamental unit of 2:4 sparsity.
pub type Sparse24Block {
Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}
Constructors
- Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
Arguments
- values: The 2 non-zero values (survivors of magnitude pruning)
- positions: Positions of the 2 kept values within the group of 4 (0-3 each, i.e. 2 bits per position). Hardware could pack both into a single 4-bit field.
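As an illustration of how two 2-bit positions fit into one 4-bit field, here is a minimal sketch. The function names and the bit layout (first position in the low 2 bits, second in the high 2 bits) are assumptions made for this example; they are not taken from viva_tensor or from NVIDIA's metadata format.

```gleam
import gleam/int

/// Pack two positions (each 0-3) into one 4-bit integer.
/// Assumed layout: first position in bits 0-1, second in bits 2-3.
pub fn pack_positions(positions: #(Int, Int)) -> Int {
  let #(first, second) = positions
  int.bitwise_or(first, int.bitwise_shift_left(second, 2))
}

/// Inverse of `pack_positions`: split a 4-bit integer back into
/// its two 2-bit positions.
pub fn unpack_positions(packed: Int) -> #(Int, Int) {
  #(int.bitwise_and(packed, 3), int.bitwise_shift_right(packed, 2))
}
```

For example, `pack_positions(#(1, 3))` gives 13 (binary 1101), and `unpack_positions(13)` gives `#(1, 3)` again.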
Tensor with 2:4 structured sparsity
Sparsity ratio S = (total - nonzero) / total = 0.5 for 2:4.
Effective FLOPS with 2:4: 2x theoretical, ~1.7x practical.
pub type Sparse24Tensor {
  Sparse24Tensor(
    blocks: List(Sparse24Block),
    shape: List(Int),
    num_elements: Int,
    memory_bytes: Int,
    sparsity_percent: Float,
  )
}
Constructors
- Sparse24Tensor(blocks: List(Sparse24Block), shape: List(Int), num_elements: Int, memory_bytes: Int, sparsity_percent: Float)
Arguments
- blocks: Sparse blocks covering the entire tensor
- shape: Original dense shape (for reconstruction)
- num_elements: Original element count (may not be divisible by 4)
- memory_bytes: Compressed memory footprint in bytes
- sparsity_percent: Sparsity achieved, as a percentage (always 50% for 2:4)
Values
pub fn benchmark_sparsity() -> Nil
pub fn compute_metrics(
  original: tensor.Tensor,
  sparse: Sparse24Tensor,
) -> PruneMetrics
Compute pruning quality metrics
Key insight: if pruned_magnitude_mean << kept_magnitude_mean, we’re making good pruning decisions. The approximation_error tells us how much information we lost.
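To make these quantities concrete, the sketch below computes them over plain `List(Float)` values rather than `tensor.Tensor`, whose API is not shown here. The function names are hypothetical, and treating "pruned" as "reconstructed value is exactly 0.0" is an assumption of this example, not necessarily how `compute_metrics` decides it.

```gleam
import gleam/float
import gleam/int
import gleam/list

pub type PruneMetrics {
  PruneMetrics(
    pruned_count: Int,
    total_count: Int,
    approximation_error: Float,
    kept_magnitude_mean: Float,
    pruned_magnitude_mean: Float,
  )
}

/// Mean of absolute values; 0.0 for an empty list.
fn mean_magnitude(values: List(Float)) -> Float {
  case values {
    [] -> 0.0
    _ -> {
      let total =
        list.fold(values, 0.0, fn(acc, v) { acc +. float.absolute_value(v) })
      total /. int.to_float(list.length(values))
    }
  }
}

/// Compare the original values with their pruned reconstruction.
pub fn metrics_from_lists(
  original: List(Float),
  reconstructed: List(Float),
) -> PruneMetrics {
  let pairs = list.zip(original, reconstructed)
  // Assumption: an element counts as pruned when its reconstruction is 0.0.
  let #(pruned, kept) =
    list.partition(pairs, fn(pair: #(Float, Float)) { pair.1 == 0.0 })
  let errors = list.map(pairs, fn(pair: #(Float, Float)) { pair.0 -. pair.1 })
  let originals = fn(items: List(#(Float, Float))) {
    list.map(items, fn(pair: #(Float, Float)) { pair.0 })
  }
  PruneMetrics(
    pruned_count: list.length(pruned),
    total_count: list.length(pairs),
    approximation_error: mean_magnitude(errors),
    kept_magnitude_mean: mean_magnitude(originals(kept)),
    pruned_magnitude_mean: mean_magnitude(originals(pruned)),
  )
}
```

For `original = [0.1, -0.9, 0.4, 0.05]` and `reconstructed = [0.0, -0.9, 0.4, 0.0]`, this reports 2 of 4 elements pruned, an approximation error of about 0.0375, a kept mean of about 0.65 and a pruned mean of about 0.075. The pruned mean is far below the kept mean, which is the pattern described above.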
pub fn decompress(sparse: Sparse24Tensor) -> tensor.Tensor
Reconstruct dense tensor from 2:4 sparse representation
This is O(n) and allocation-heavy. In CUDA, you’d keep it sparse and let the Tensor Core handle the pattern. On CPU, you often need to decompress for compatibility with dense operations.
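The reconstruction step can be sketched over plain lists as follows. This is an illustration under stated assumptions, not the library's implementation: it returns a `List(Float)` instead of a `tensor.Tensor`, the `Sparse24Block` type is redeclared locally so the snippet is self-contained, and `expand` and `decompress_values` are hypothetical names.

```gleam
import gleam/list

pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Expand one block into its 4 dense slots; pruned slots become 0.0.
fn expand(block: Sparse24Block) -> List(Float) {
  let Sparse24Block(values: #(v0, v1), positions: #(p0, p1)) = block
  list.map([0, 1, 2, 3], fn(slot) {
    case slot == p0, slot == p1 {
      True, _ -> v0
      _, True -> v1
      _, _ -> 0.0
    }
  })
}

/// Rebuild the dense values, then drop any padding added when the
/// original element count was not a multiple of 4.
pub fn decompress_values(
  blocks: List(Sparse24Block),
  num_elements: Int,
) -> List(Float) {
  blocks
  |> list.flat_map(expand)
  |> list.take(num_elements)
}
```

For example, `decompress_values([Sparse24Block(#(3.0, -2.0), #(1, 3))], 4)` yields `[0.0, 3.0, 0.0, -2.0]`. Every block allocates a fresh 4-element list, which shows where the allocation cost mentioned above comes from.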
pub fn prune_24_gradient(
  weights: tensor.Tensor,
  gradients: tensor.Tensor,
) -> Sparse24Tensor
Gradient-weighted pruning for training scenarios
Importance = |weight * gradient|
Intuition: weights that are both large and have large gradients matter most, because the loss is most sensitive to them. This tends to work better than magnitude alone during fine-tuning.
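A minimal sketch of the importance score applied to a single group of 4 is shown below. It works on plain lists, redeclares `Sparse24Block` locally, and uses hypothetical helper names (`Candidate`, `score_group`, `prune_group_by_gradient`); how `prune_24_gradient` handles ties or ragged tails is not documented here, so this sketch simply returns `Error(Nil)` when fewer than 2 weight/gradient pairs are supplied.

```gleam
import gleam/float
import gleam/int
import gleam/list

pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

type Candidate {
  Candidate(position: Int, weight: Float, score: Float)
}

/// Pair each weight with its position and its importance |w * g|.
fn score_group(
  weights: List(Float),
  gradients: List(Float),
  position: Int,
) -> List(Candidate) {
  case weights, gradients {
    [w, ..ws], [g, ..gs] -> [
      Candidate(position, w, float.absolute_value(w *. g)),
      ..score_group(ws, gs, position + 1)
    ]
    _, _ -> []
  }
}

/// Keep the 2 weights with the highest importance in one group.
pub fn prune_group_by_gradient(
  weights: List(Float),
  gradients: List(Float),
) -> Result(Sparse24Block, Nil) {
  // Highest importance first.
  let ranked =
    list.sort(score_group(weights, gradients, 0), fn(a: Candidate, b: Candidate) {
      float.compare(b.score, a.score)
    })
  // Survivors, restored to ascending position order.
  let kept =
    list.sort(list.take(ranked, 2), fn(a: Candidate, b: Candidate) {
      int.compare(a.position, b.position)
    })
  case kept {
    [first, second] ->
      Ok(Sparse24Block(
        #(first.weight, second.weight),
        #(first.position, second.position),
      ))
    _ -> Error(Nil)
  }
}
```

With weights `[0.9, 0.1, -0.5, 0.3]` and gradients `[0.0, 2.0, 0.1, 1.0]`, the scores are 0.0, 0.2, 0.05 and 0.3, so the sketch keeps positions 1 and 3 (values 0.1 and 0.3). Pure magnitude pruning would instead have kept 0.9 and -0.5.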
pub fn prune_24_magnitude(t: tensor.Tensor) -> Sparse24Tensor
Apply 2:4 pruning using magnitude-based selection
Strategy: keep the 2 largest (by absolute value) in each group of 4. This is NVIDIA’s recommended approach and works well empirically.
Theoretical justification: large weights carry more information. Empirical validation: <1% accuracy drop on ImageNet, BERT, GPT-2.
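The keep-the-2-largest rule can be sketched on plain lists like this. It is an illustration rather than the library's code: `Sparse24Block` is redeclared locally, `prune_group` and `prune_values` are hypothetical names, and zero-padding a short final group is an assumption about how a ragged tail might be handled.

```gleam
import gleam/float
import gleam/int
import gleam/list

pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Prune one group: keep the 2 entries with the largest absolute value.
/// Groups shorter than 4 are zero-padded first (an assumption of this sketch).
pub fn prune_group(group: List(Float)) -> Sparse24Block {
  let padded = list.take(list.append(group, [0.0, 0.0, 0.0, 0.0]), 4)
  // Largest magnitude first.
  let ranked =
    list.sort(list.zip([0, 1, 2, 3], padded), fn(a: #(Int, Float), b: #(Int, Float)) {
      float.compare(float.absolute_value(b.1), float.absolute_value(a.1))
    })
  // Survivors, restored to ascending position order.
  let kept =
    list.sort(list.take(ranked, 2), fn(a: #(Int, Float), b: #(Int, Float)) {
      int.compare(a.0, b.0)
    })
  case kept {
    [#(p0, v0), #(p1, v1)] -> Sparse24Block(#(v0, v1), #(p0, p1))
    // Unreachable: a padded group always yields two survivors.
    _ -> Sparse24Block(#(0.0, 0.0), #(0, 1))
  }
}

/// Split a flat list into groups of 4 and prune each group.
pub fn prune_values(values: List(Float)) -> List(Sparse24Block) {
  values
  |> list.sized_chunk(into: 4)
  |> list.map(prune_group)
}
```

For the group `[0.1, -0.9, 0.4, 0.05]` this keeps -0.9 and 0.4 at positions 1 and 2, giving `Sparse24Block(#(-0.9, 0.4), #(1, 2))`.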
pub fn sparse_matmul(
  sparse_a: Sparse24Tensor,
  dense_b: tensor.Tensor,
) -> #(tensor.Tensor, Float)
Sparse matrix multiplication (simulated)
On Tensor Cores, the hardware skips 50% of the multiplications. Here we decompress and multiply densely for correctness.
Returns: (result, theoretical_speedup)
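To illustrate where the skipped multiplications come from, the sketch below computes a dot product between one 2:4-compressed row and a dense vector, doing 2 multiplications per block instead of 4. This is not what `sparse_matmul` does (it decompresses first, as noted above); `sparse_dot` and `at` are hypothetical names, the inputs are plain lists, and `Sparse24Block` is redeclared locally.

```gleam
import gleam/list

pub type Sparse24Block {
  Sparse24Block(values: #(Float, Float), positions: #(Int, Int))
}

/// Element at `index` in a group, or 0.0 if the group is too short.
fn at(group: List(Float), index: Int) -> Float {
  case list.drop(group, index) {
    [value, ..] -> value
    [] -> 0.0
  }
}

/// Dot product of a 2:4-compressed row with a dense vector.
/// Only the 2 kept values of each block are multiplied.
pub fn sparse_dot(blocks: List(Sparse24Block), dense: List(Float)) -> Float {
  let groups = list.sized_chunk(dense, into: 4)
  list.zip(blocks, groups)
  |> list.fold(0.0, fn(acc, pair) {
    let #(Sparse24Block(values: #(v0, v1), positions: #(p0, p1)), group) = pair
    acc +. v0 *. at(group, p0) +. v1 *. at(group, p1)
  })
}
```

For the row `[0.0, 3.0, 0.0, -2.0]`, stored as `Sparse24Block(#(3.0, -2.0), #(1, 3))`, and the dense vector `[1.0, 2.0, 3.0, 4.0]`, the result is 3.0 * 2.0 + (-2.0) * 4.0 = -2.0, the same value the dense dot product gives.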