viva_tensor/core/config

Configuration Types with Builder Pattern

Builder pattern: Gang of Four called, they want royalties. But seriously, this beats magic numbers in function calls any day.

The idea: sensible defaults + fluent customization. You get a working config out of the box, then override only what you need. No more googling “what’s the default stride for conv2d in pytorch”.

Design Philosophy

  1. Defaults are production-ready, not toy examples
  2. Labelled arguments for required params, builders for optional ones
  3. Every config should be printable for debugging (thanks Gleam!)

Example

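import gleam/io
import gleam/string
import viva_tensor/core/config

pub fn main() {
  // Start with defaults, customize what matters
  let strided =
    config.conv2d()
    |> config.with_stride(2)
    |> config.with_padding(1)

  // Or spell out kernel, stride and padding in one call
  let explicit =
    config.conv2d_new(
      kernel_h: 5,
      kernel_w: 5,
      stride: 2,
      padding: 1,
    )

  // Configs are plain records, so they print cleanly for debugging
  io.println(string.inspect(strided))
  io.println(string.inspect(explicit))
}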

Types

AWQ (Activation-Aware Weight Quantization) configuration.

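Lin et al. (2023) key insight: not all weights are equal. Protecting only ~1% of weights, the “salient” channels, greatly reduces quantization error. Saliency is judged by activations, not by the weights themselves: salient channels are the input channels with large activation magnitudes.

Solution: multiply salient weight channels by a scale s > 1 before quantization and divide the matching input activations by s, so the layer output is unchanged. Because the quantization step barely moves, the salient weights’ relative rounding error shrinks by roughly s, which effectively gives them more precision where it matters.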

Requires calibration data to identify salient channels.
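A toy sketch of why the scaling trick above helps, assuming the quantization step `delta` stays fixed after scaling (the paper’s observation that scaling one channel rarely changes a group’s step size). `quantize`, `quantize_scaled` and `abs_error` are hypothetical helpers for illustration, not part of viva_tensor.

```gleam
import gleam/float
import gleam/int

/// Round-to-nearest onto a uniform grid with step `delta`
/// (a stand-in for one low-bit quantization group).
pub fn quantize(w: Float, delta: Float) -> Float {
  int.to_float(float.round(w /. delta)) *. delta
}

/// AWQ-style trick: scale the weight up by `s` before quantizing, then divide
/// by `s`. In a real layer the 1/s is folded into the input activations, so
/// w * x is unchanged. Assumes `delta` is not changed by the scaling.
pub fn quantize_scaled(w: Float, s: Float, delta: Float) -> Float {
  quantize(w *. s, delta) /. s
}

/// Absolute reconstruction error.
pub fn abs_error(original: Float, approx: Float) -> Float {
  float.absolute_value(original -. approx)
}
```

With `delta = 0.1`, a salient weight of 0.04 rounds to 0.0 (error 0.04) when quantized directly, but comes back as 0.05 (error about 0.01) with `s = 2.0`.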

pub type AWQConfig {
  AWQConfig(
    block_size: Int,
    n_calibration: Int,
    scale_factor: Float,
  )
}

Constructors

  • AWQConfig(
      block_size: Int,
      n_calibration: Int,
      scale_factor: Float,
    )

    Arguments

    block_size

    Block size for underlying 4-bit quantization

    n_calibration

    Number of calibration samples used to collect activation statistics. More = better saliency estimation, but diminishing returns after ~128.

    scale_factor

    Scaling factor for salient channels. Higher = more protection for important weights, but can cause numerical issues if too high. The paper grid-searches the scale per layer as s = s_X^α, where s_X is the average activation magnitude of each input channel and α ranges over [0, 1].
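A minimal sketch of the per-channel scales explored by that search, assuming the mean absolute activations have already been collected from calibration data. `channel_scales` is a hypothetical helper; the outer loop (quantize with each candidate α, measure output error, keep the best) is omitted.

```gleam
import gleam/float
import gleam/list

/// Per-input-channel AWQ scale s = s_X ^ alpha, where s_X is the mean
/// absolute activation of that channel over the calibration set.
/// alpha = 0.0 disables scaling (every s = 1.0); larger alpha protects
/// high-activation channels more aggressively.
pub fn channel_scales(mean_abs_activation: List(Float), alpha: Float) -> List(Float) {
  list.map(mean_abs_activation, fn(s_x) {
    case float.power(s_x, alpha) {
      Ok(s) -> s
      // Unreachable for non-negative magnitudes with alpha >= 0;
      // fall back to no scaling
      Error(Nil) -> 1.0
    }
  })
}
```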

Flash Attention configuration.

Dao et al. (2022), “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”. The algorithm that made long-context LLMs practical.

Key insight: attention is memory-bound, not compute-bound. By tiling the computation to fit in SRAM, we avoid reading/writing the O(n^2) attention matrix to HBM. 2-4x faster, O(n) memory.

This config is for the attention parameters, not the Flash algorithm itself. The algorithm is handled in the implementation.

pub type AttentionConfig {
  AttentionConfig(
    num_heads: Int,
    head_dim: Int,
    dropout: Float,
    causal: Bool,
    scale: Float,
  )
}

Constructors

  • AttentionConfig(
      num_heads: Int,
      head_dim: Int,
      dropout: Float,
      causal: Bool,
      scale: Float,
    )

    Arguments

    num_heads

    Number of attention heads. More heads = more expressive, more memory. GPT-3: 96 heads. BERT-base: 12 heads. Choose based on model size.

    head_dim

    Dimension of each head (d_k = d_model / num_heads typically). 64 is common. Must divide evenly into total embedding dimension.

    dropout

    Dropout probability on attention weights. 0.0 for inference. Training typically uses 0.1. Higher for small datasets to prevent overfitting.

    causal

    Causal (autoregressive) masking. True for decoder-only models (GPT). Prevents attending to future tokens. False for encoder models (BERT).

    scale

    Softmax scaling factor. Typically 1/sqrt(head_dim). Vaswani et al. (2017) added it because large dot products push the softmax into saturated regions with vanishingly small gradients.

Conv2D operation configuration.

Supports all the usual suspects: stride, padding, dilation, groups. Asymmetric values supported (different H and W) because sometimes your input isn’t square and you shouldn’t have to pretend it is.

pub type Conv2dConfig {
  Conv2dConfig(
    kernel_h: Int,
    kernel_w: Int,
    stride_h: Int,
    stride_w: Int,
    padding_h: Int,
    padding_w: Int,
    dilation_h: Int,
    dilation_w: Int,
    groups: Int,
  )
}

Constructors

  • Conv2dConfig(
      kernel_h: Int,
      kernel_w: Int,
      stride_h: Int,
      stride_w: Int,
      padding_h: Int,
      padding_w: Int,
      dilation_h: Int,
      dilation_w: Int,
      groups: Int,
    )

    Arguments

    kernel_h

    Kernel height - typically 1, 3, 5, or 7. 3x3 is the ResNet sweet spot.

    kernel_w

    Kernel width - usually same as height, but asymmetric kernels exist

    stride_h

    Stride height - controls output spatial reduction. 2 = halve the size.

    stride_w

    Stride width

    padding_h

    Padding height - add zeros around input. padding=kernel/2 keeps size for odd kernels at stride 1.

    padding_w

    Padding width

    dilation_h

    Dilation height - “atrous” convolution, increases receptive field without adding parameters. DeepLab relies on it heavily. A value of 1 is standard convolution.

    dilation_w

    Dilation width

    groups

    Groups for grouped/depthwise convolution. groups=in_channels is depthwise. MobileNet’s secret sauce for efficient inference.

INT8 quantization configuration.

The production workhorse. Jacob et al. (2017) showed INT8 inference loses <1% accuracy on ImageNet while being 4x smaller and up to 4x faster.

Two modes:

  • Per-tensor (block_size=0): One scale for the whole tensor. Fastest.
  • Per-block: One scale per block of N elements. More accurate for outliers.

pub type Int8Config {
  Int8Config(block_size: Int, symmetric: Bool)
}

Constructors

  • Int8Config(block_size: Int, symmetric: Bool)

    Arguments

    block_size

    Block size for per-block quantization. 0 = per-tensor (one global scale). Per-block helps when tensor has outliers in specific regions.

    symmetric

    Symmetric quantization: range is [-127, 127], zero maps to zero. Asymmetric: range is [0, 255] with a zero-point offset. Symmetric is simpler and faster, asymmetric handles skewed distributions.
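A minimal sketch of the symmetric scheme on plain lists, assuming round-to-nearest and a scale of max|x| / 127 per group. `quantize_symmetric`, `quantize_blocks` and `dequantize` are hypothetical helpers, not viva_tensor’s implementation; `quantize_blocks` mirrors the block_size convention described above (0 = per-tensor).

```gleam
import gleam/float
import gleam/int
import gleam/list

/// Symmetric INT8 quantization of one group: scale = max|x| / 127,
/// q = round(x / scale) clamped to [-127, 127], so 0.0 maps to 0.
pub fn quantize_symmetric(values: List(Float)) -> #(Float, List(Int)) {
  let max_abs =
    list.fold(values, 0.0, fn(acc, x) { float.max(acc, float.absolute_value(x)) })
  // An all-zero group gets scale 1.0 to avoid dividing by zero
  let scale = case max_abs == 0.0 {
    True -> 1.0
    False -> max_abs /. 127.0
  }
  let q =
    list.map(values, fn(x) {
      int.clamp(float.round(x /. scale), min: -127, max: 127)
    })
  #(scale, q)
}

/// block_size = 0: one scale for the whole tensor.
/// block_size = N > 0: one scale per block of N elements.
/// Assumes block_size >= 0.
pub fn quantize_blocks(values: List(Float), block_size: Int) -> List(#(Float, List(Int))) {
  case block_size {
    0 -> [quantize_symmetric(values)]
    n -> values |> list.sized_chunk(into: n) |> list.map(quantize_symmetric)
  }
}

/// Reconstruct approximate floats from a scale and its INT8 codes.
pub fn dequantize(scale: Float, q: List(Int)) -> List(Float) {
  list.map(q, fn(v) { int.to_float(v) *. scale })
}
```

Per-block mode simply runs the per-tensor routine on each chunk, so an outlier only inflates the scale of its own block.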

NF4 (NormalFloat4) quantization configuration.

From Dettmers et al. (2023) “QLoRA: Efficient Finetuning of Quantized LLMs”

Key parameters:

  • block_size: Number of weights sharing one scale factor. Smaller = more accurate but more overhead. 64 is the sweet spot from the paper.
  • double_quant: Quantize the scale factors themselves (FP32 -> FP8). Saves 0.37 bits/param with negligible quality loss. Free lunch.

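Memory math for a 7B model (1 GB = 10^9 bytes):

  • FP16: 7B * 2 bytes = 14 GB
  • NF4: 7B * 0.5 bytes = 3.5 GB of weights, plus FP32 scales at 0.5 bits/param (~0.44 GB) = ~3.94 GB
  • NF4 + double_quant: scale overhead drops to ~0.127 bits/param (~0.11 GB), for ~3.61 GB total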
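A quick way to reproduce these numbers. `bits_per_param` and `model_gb` are hypothetical helpers; the double-quantization layout (8-bit scales plus one FP32 constant per 256 scales) follows the QLoRA paper and is an assumption about, not a description of, viva_tensor’s storage format.

```gleam
import gleam/int

/// Bits stored per parameter for 4-bit blockwise quantization.
/// Without double quantization each block carries one FP32 scale (32 bits).
/// With it (QLoRA layout), scales are stored in 8 bits and every group of
/// 256 scales carries one FP32 constant.
pub fn bits_per_param(block_size: Int, double_quant: Bool) -> Float {
  let b = int.to_float(block_size)
  case double_quant {
    False -> 4.0 +. 32.0 /. b
    True -> 4.0 +. 8.0 /. b +. 32.0 /. { b *. 256.0 }
  }
}

/// Model size in GB (10^9 bytes) for `params` parameters at `bits` each.
pub fn model_gb(params: Float, bits: Float) -> Float {
  params *. bits /. 8.0 /. 1.0e9
}
```

For example, `model_gb(7.0e9, bits_per_param(64, False))` gives 3.9375, and switching `double_quant` on brings it to about 3.61.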

pub type NF4Config {
  NF4Config(block_size: Int, double_quant: Bool)
}

Constructors

  • NF4Config(block_size: Int, double_quant: Bool)

    Arguments

    block_size

    Block size for quantization. Each block of N weights shares one scale. Smaller = better quality, more memory overhead. Paper uses 64, which gives ~0.5 bits overhead per weight.

    double_quant

    Double quantization: quantize the FP32 scales to FP8. Saves 0.37 bits/param with no measurable quality loss. No reason not to use this.

Pooling operation configuration.

  • MaxPool: take the max in each window. Good for translation invariance.
  • AvgPool: take the mean. Smoother gradients, sometimes better for deep networks.
  • GlobalAvgPool: pool_size = input_size, giving one value per channel. Typically used right before the classification head.

pub type PoolConfig {
  PoolConfig(
    pool_h: Int,
    pool_w: Int,
    stride_h: Int,
    stride_w: Int,
    padding_h: Int,
    padding_w: Int,
  )
}

Constructors

  • PoolConfig(
      pool_h: Int,
      pool_w: Int,
      stride_h: Int,
      stride_w: Int,
      padding_h: Int,
      padding_w: Int,
    )

    Arguments

    pool_h

    Pool window height

    pool_w

    Pool window width

    stride_h

    Stride height (typically equal to pool_h, giving non-overlapping windows)

    stride_w

    Stride width

    padding_h

    Padding height (usually 0 for pooling)

    padding_w

    Padding width

Values

pub fn attention(
  num_heads num_heads: Int,
  head_dim head_dim: Int,
) -> AttentionConfig

Create attention config with required parameters.

Scale is automatically set to 1/sqrt(head_dim) per Vaswani et al. (2017). Override with attention_with_scale() if you’re doing something exotic (e.g., cosine attention, ALiBi without scaling).
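One plausible shape for this constructor, sketched against a local copy of the documented `AttentionConfig` fields. The defaults `dropout: 0.0` and `causal: False` are assumptions (the docs above do not state them), and this is not the library’s source.

```gleam
import gleam/float
import gleam/int

/// Local copy of the documented AttentionConfig fields.
pub type AttentionConfig {
  AttentionConfig(
    num_heads: Int,
    head_dim: Int,
    dropout: Float,
    causal: Bool,
    scale: Float,
  )
}

/// Required parameters only; scale is derived as 1 / sqrt(head_dim).
/// dropout = 0.0 and causal = False are assumed defaults.
pub fn attention(num_heads num_heads: Int, head_dim head_dim: Int) -> AttentionConfig {
  let scale = case float.square_root(int.to_float(head_dim)) {
    Ok(root) if root >. 0.0 -> 1.0 /. root
    // Non-positive head_dim: fall back to no scaling
    _ -> 1.0
  }
  AttentionConfig(
    num_heads: num_heads,
    head_dim: head_dim,
    dropout: 0.0,
    causal: False,
    scale: scale,
  )
}
```

With the common `head_dim: 64`, the derived scale is exactly 0.125.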

pub fn attention_causal(
  config: AttentionConfig,
) -> AttentionConfig

Enable causal (autoregressive) masking.

For decoder-only models (GPT, LLaMA, etc.) that should only attend to past tokens, not future ones. Implemented as a triangular mask.
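A minimal sketch of the triangular mask as nested lists of booleans. `causal_mask` is a hypothetical helper; real kernels usually apply the mask by setting the disallowed scores to negative infinity before the softmax so they get zero weight.

```gleam
import gleam/list

/// mask[i][j] is True when query position i may attend to key position j,
/// i.e. j <= i (lower triangle, diagonal included).
pub fn causal_mask(seq_len: Int) -> List(List(Bool)) {
  let positions = indices(seq_len)
  list.map(positions, fn(i) { list.map(positions, fn(j) { j <= i }) })
}

/// [0, 1, ..., n - 1], built without any range helper.
fn indices(n: Int) -> List(Int) {
  build_indices(n - 1, [])
}

fn build_indices(i: Int, acc: List(Int)) -> List(Int) {
  case i < 0 {
    True -> acc
    False -> build_indices(i - 1, [i, ..acc])
  }
}
```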

pub fn attention_with_dropout(
  config: AttentionConfig,
  dropout: Float,
) -> AttentionConfig

Set dropout probability on attention weights.

0.0 for inference (always). 0.1 is typical for training. Higher (0.2-0.3) for small datasets or aggressive regularization.

pub fn attention_with_scale(
  config: AttentionConfig,
  scale: Float,
) -> AttentionConfig

Override the softmax scaling factor.

Default is 1/sqrt(head_dim) which prevents attention logits from growing too large as dimension increases. You might override this for:

  • Cosine attention (scale=1, use normalized Q and K)
  • ALiBi without scaling (Press et al., 2022)
  • Experimental attention variants

pub fn awq() -> AWQConfig

Default AWQ: block_size=64, 128 calibration samples, scale=1.0.

scale_factor=1.0 applies no extra scaling, so it is effectively a placeholder. In practice you’d tune this per layer using calibration data; see the paper for the grid search procedure.

pub fn awq_with_block_size(
  config: AWQConfig,
  block_size: Int,
) -> AWQConfig

Set AWQ block size.

pub fn awq_with_calibration(
  config: AWQConfig,
  n: Int,
) -> AWQConfig

Set number of calibration samples for saliency estimation.

pub fn conv2d() -> Conv2dConfig

Default Conv2d: 3x3 kernel, stride 1, no padding, no dilation.

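Why 3x3? VGGNet (2014) showed stacking 3x3s beats larger kernels. Two 3x3s have the same receptive field as one 5x5 but fewer params, and three 3x3s match the receptive field of one 7x7. This is why ResNet builds its residual blocks from 3x3 convolutions (its stem is a single 7x7, and the bottleneck variants add 1x1s).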
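The arithmetic behind the VGG argument, assuming stride-1 layers with the same channel count C in and out and no bias: two 3x3 layers cost 18C² weights versus 25C² for one 5x5, and three 3x3 layers cost 27C² versus 49C² for one 7x7. The helper names are hypothetical.

```gleam
/// Receptive field (per axis) of n stacked stride-1 k x k convolutions.
pub fn stacked_receptive_field(n: Int, k: Int) -> Int {
  1 + n * { k - 1 }
}

/// Weight count (no bias) of n stacked k x k convolutions, each with
/// `channels` input and output channels.
pub fn stacked_params(n: Int, k: Int, channels: Int) -> Int {
  n * k * k * channels * channels
}
```

The stack also gets a nonlinearity between each pair of layers, which a single large kernel does not.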

Warning: stride=1 + no padding shrinks each spatial dimension by kernel-1 pixels in total, i.e. (kernel-1)/2 per side for odd kernels (a 3x3 turns 32x32 into 30x30). For “same” output size, use conv2d_same() or add padding=kernel/2.
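The standard output-size formula makes this concrete. A minimal sketch: `conv_output_size` is a hypothetical helper, not part of viva_tensor’s documented API; it uses floor division as most frameworks do and also covers pooling with `dilation: 1`.

```gleam
/// Output length along one spatial axis. Assumes
/// input + 2 * padding >= the effective (dilated) kernel size.
pub fn conv_output_size(
  input input: Int,
  kernel kernel: Int,
  stride stride: Int,
  padding padding: Int,
  dilation dilation: Int,
) -> Int {
  let effective_kernel = dilation * { kernel - 1 } + 1
  { input + 2 * padding - effective_kernel } / stride + 1
}
```

For a 32-pixel input and a 3x3 kernel: no padding gives 30, `padding: 1` keeps 32, and adding `stride: 2` halves it to 16.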

pub fn conv2d_new(
  kernel_h kernel_h: Int,
  kernel_w kernel_w: Int,
  stride stride: Int,
  padding padding: Int,
) -> Conv2dConfig

Explicit Conv2d config with labelled arguments.

Use this when you know exactly what you want and don’t need the builder pattern’s incremental customization.

pub fn conv2d_same(kernel_h: Int, kernel_w: Int) -> Conv2dConfig

“Same” padding: output spatial size equals input spatial size.

Computes padding = kernel_size / 2 (integer division). Only works correctly for odd kernel sizes with stride=1. For even kernels or stride>1, you need asymmetric padding (not supported here).

Note: PyTorch’s “same” padding pads asymmetrically when the total padding is odd (e.g., even kernels). We keep it simple.

pub fn int8() -> Int8Config

Default INT8: per-tensor symmetric quantization.

Simplest and fastest. Works well when weight distributions are roughly symmetric around zero (which they usually are for trained models).

pub fn int8_with_block_size(
  config: Int8Config,
  block_size: Int,
) -> Int8Config

Set INT8 block size for per-block quantization. 0 = per-tensor quantization (default, fastest).

pub fn nf4() -> NF4Config

Default NF4: block_size=64, double_quant=True (paper settings).

These are the settings QLoRA used to match full 16-bit finetuning performance. Don’t change them unless you know why.

pub fn nf4_with_block_size(
  config: NF4Config,
  block_size: Int,
) -> NF4Config

Override NF4 block size.

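Smaller = better quality, more scale overhead. With one FP32 scale per block (no double quantization), the overhead is 32 / block_size bits per param:

  • 32: 1.0 bits/param overhead, slightly better quality
  • 64: 0.5 bits/param overhead (default, paper setting)
  • 128: 0.25 bits/param overhead, slightly worse quality

Double quantization shrinks these further (to ~0.127 bits/param at block size 64).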

pub fn nf4_with_double_quant(
  config: NF4Config,
  enabled: Bool,
) -> NF4Config

Enable/disable double quantization. There’s really no reason to disable this, but the option exists.

pub fn pool() -> PoolConfig

Default pooling: 2x2 window, stride 2, no padding.

The classic maxpool config: halves spatial dimensions. Non-overlapping windows (stride = pool_size).

Fun fact: Hinton thinks pooling is a mistake because it throws away spatial information. Capsule networks are his proposed fix. The jury’s still out.
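What the default 2x2, stride-2 config computes, shown on a plain list-of-rows grid. A toy sketch: `max_pool_2x2` is hypothetical and works on nested lists rather than viva_tensor’s tensor type; it assumes an even width and drops a trailing odd row, matching the no-padding floor behaviour.

```gleam
import gleam/int
import gleam/list

/// 2x2 max pooling with stride 2 and no padding on a row-major grid.
pub fn max_pool_2x2(input: List(List(Int))) -> List(List(Int)) {
  input
  |> list.sized_chunk(into: 2)
  |> list.filter_map(fn(row_pair) {
    case row_pair {
      [top, bottom] -> Ok(pool_row_pair(top, bottom))
      // A trailing single row (odd height) has no full window
      _ -> Error(Nil)
    }
  })
}

/// Pair up 2-wide column windows from two rows and keep each window's max.
fn pool_row_pair(top: List(Int), bottom: List(Int)) -> List(Int) {
  list.map2(
    list.sized_chunk(top, into: 2),
    list.sized_chunk(bottom, into: 2),
    fn(a, b) { window_max(list.append(a, b)) },
  )
}

fn window_max(values: List(Int)) -> Int {
  case values {
    [first, ..rest] -> list.fold(rest, first, int.max)
    [] -> 0
  }
}
```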

pub fn pool_new(
  pool_size pool_size: Int,
  stride stride: Int,
) -> PoolConfig

Create pool config with explicit size and stride.

pub fn pool_with_padding(
  config: PoolConfig,
  padding: Int,
) -> PoolConfig

Set pool padding.

pub fn pool_with_size(
  config: PoolConfig,
  pool_h: Int,
  pool_w: Int,
) -> PoolConfig

Set pool window size.

pub fn pool_with_stride(
  config: PoolConfig,
  stride: Int,
) -> PoolConfig

Set pool stride.

pub fn with_dilation(
  config: Conv2dConfig,
  dilation: Int,
) -> Conv2dConfig

Set uniform dilation for atrous/dilated convolution.

dilation=2 inserts a one-pixel gap between kernel taps, so the kernel samples every other pixel. Effective kernel size = dilation * (kernel - 1) + 1, so a 3x3 kernel with dilation=2 has a 5x5 receptive field but still only 9 params.

DeepLab uses this for semantic segmentation - large receptive field without the parameter explosion of large kernels.
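The effective-kernel formula, plus what happens when stride-1 dilated layers are stacked with growing dilation. The 1, 2, 4 schedule is only an example, and both helper names are hypothetical: three 3x3 layers (9 weights per channel pair each) reach a 15x15 receptive field, versus 7x7 without dilation.

```gleam
import gleam/list

/// Span covered by a k-tap kernel with the given dilation.
pub fn effective_kernel(kernel: Int, dilation: Int) -> Int {
  dilation * { kernel - 1 } + 1
}

/// Receptive field of a stack of stride-1 convs sharing one kernel size,
/// one dilation per layer: each layer adds dilation * (kernel - 1).
pub fn stacked_dilated_rf(kernel: Int, dilations: List(Int)) -> Int {
  list.fold(dilations, 1, fn(rf, d) { rf + d * { kernel - 1 } })
}
```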

pub fn with_groups(
  config: Conv2dConfig,
  groups: Int,
) -> Conv2dConfig

Set number of groups for grouped convolution.

  • groups=1: standard convolution, all inputs connect to all outputs
  • groups=in_channels: depthwise convolution (MobileNet, EfficientNet)
  • groups=N: grouped convolution, splits channels into N independent groups

Depthwise + 1x1 pointwise = depthwise separable convolution. It cuts compute by roughly a factor of kernel_size^2 (8-9x for 3x3 kernels) with minimal accuracy loss.
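Counting multiply-accumulates shows where the saving comes from: the separable/standard cost ratio is 1/c_out + 1/k², so for wide layers it approaches 1/k². A sketch assuming stride 1, an h x w output and no bias; the helper names are hypothetical.

```gleam
/// MACs for a standard k x k convolution producing an h x w output map.
pub fn standard_macs(h: Int, w: Int, c_in: Int, c_out: Int, k: Int) -> Int {
  h * w * c_in * c_out * k * k
}

/// Depthwise k x k conv (groups = c_in) followed by a 1x1 pointwise conv.
pub fn separable_macs(h: Int, w: Int, c_in: Int, c_out: Int, k: Int) -> Int {
  let depthwise = h * w * c_in * k * k
  let pointwise = h * w * c_in * c_out
  depthwise + pointwise
}
```

At 56x56 with 128 channels in and out and a 3x3 kernel, that is about 462M MACs versus 55M, roughly 8.4x fewer.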

pub fn with_kernel(
  config: Conv2dConfig,
  kernel_h: Int,
  kernel_w: Int,
) -> Conv2dConfig

Set kernel size (height and width).

pub fn with_padding(
  config: Conv2dConfig,
  padding: Int,
) -> Conv2dConfig

Set uniform padding.

For stride=1:

  • padding=1 with a 3x3 kernel maintains spatial size
  • padding=2 with a 5x5 kernel maintains spatial size
  • General rule: padding = (kernel_size - 1) / 2 for “same” output with odd kernels

pub fn with_padding_hw(
  config: Conv2dConfig,
  padding_h: Int,
  padding_w: Int,
) -> Conv2dConfig

Set separate paddings for height and width.

pub fn with_stride(
  config: Conv2dConfig,
  stride: Int,
) -> Conv2dConfig

Set uniform stride (same for H and W).

stride=2 is the standard way to downsample - halves spatial dimensions. More efficient than conv+maxpool and learns the downsampling.
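One plausible way to implement builders like this is Gleam’s record update syntax. The sketch below uses a local copy of `Conv2dConfig` with the documented conv2d() defaults; `groups: 1` is an assumption, and this is not the library’s actual source.

```gleam
/// Local copy of the documented Conv2dConfig fields.
pub type Conv2dConfig {
  Conv2dConfig(
    kernel_h: Int,
    kernel_w: Int,
    stride_h: Int,
    stride_w: Int,
    padding_h: Int,
    padding_w: Int,
    dilation_h: Int,
    dilation_w: Int,
    groups: Int,
  )
}

/// Documented defaults: 3x3 kernel, stride 1, no padding, no dilation.
/// groups = 1 (standard convolution) is assumed.
pub fn conv2d() -> Conv2dConfig {
  Conv2dConfig(
    kernel_h: 3,
    kernel_w: 3,
    stride_h: 1,
    stride_w: 1,
    padding_h: 0,
    padding_w: 0,
    dilation_h: 1,
    dilation_w: 1,
    groups: 1,
  )
}

/// Copy the record and overwrite only the stride fields.
pub fn with_stride(config: Conv2dConfig, stride: Int) -> Conv2dConfig {
  Conv2dConfig(..config, stride_h: stride, stride_w: stride)
}

/// Copy the record and overwrite only the padding fields.
pub fn with_padding(config: Conv2dConfig, padding: Int) -> Conv2dConfig {
  Conv2dConfig(..config, padding_h: padding, padding_w: padding)
}
```

Because every builder returns a fresh record, a default config can be shared and customized in several pipelines without any of them affecting the others.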

pub fn with_stride_hw(
  config: Conv2dConfig,
  stride_h: Int,
  stride_w: Int,
) -> Conv2dConfig

Set separate strides for height and width. Rarely needed, but here for completeness.
