viva_tensor/awq

AWQ (Activation-aware Weight Quantization)

MLSys 2024 BEST PAPER AWARD! https://arxiv.org/abs/2306.00978

INSIGHT PRINCIPAL: Apenas ~1% dos pesos são “salientes” - identificados pela magnitude das ATIVAÇÕES, não dos pesos!

ALGORITMO:

  1. Coletar estatísticas de ativação (calibration)
  2. Identificar canais salientes (alta ativação média)
  3. Escalar canais salientes PARA CIMA antes de quantizar
  4. Escalar ativações de entrada PARA BAIXO (matematicamente equivalente)

RESULTADO: Mesma compressão, MUITO menos erro!

Implementação: MIT-HAN Lab + AutoAWQ

Types

Configuração AWQ

pub type AWQConfig {
  AWQConfig(
    bits: Int,
    group_size: Int,
    alpha: Float,
    zero_point: Bool,
  )
}

Constructors

  • AWQConfig(
      bits: Int,
      group_size: Int,
      alpha: Float,
      zero_point: Bool,
    )

    Arguments

    bits

    Bits de quantização (4 é padrão)

    group_size

    Tamanho do grupo para scales

    alpha

    Expoente alpha para scaling (0.5 é típico)

    zero_point

    Usar zero-point (assimétrico)

Scales AWQ computados

pub type AWQScales {
  AWQScales(
    weight_scales: List(Float),
    activation_stats: List(Float),
    alpha: Float,
  )
}

Constructors

  • AWQScales(
      weight_scales: List(Float),
      activation_stats: List(Float),
      alpha: Float,
    )

    Arguments

    weight_scales

    Scales por canal (multiplicador de pesos)

    activation_stats

    Estatísticas de ativação usadas

    alpha

    Alpha usado

Tensor quantizado com AWQ

pub type AWQTensor {
  AWQTensor(
    quantized_weights: List(Int),
    awq_scales: AWQScales,
    quant_scales: List(Float),
    zero_points: List(Int),
    shape: List(Int),
    memory_bytes: Int,
  )
}

Constructors

  • AWQTensor(
      quantized_weights: List(Int),
      awq_scales: AWQScales,
      quant_scales: List(Float),
      zero_points: List(Int),
      shape: List(Int),
      memory_bytes: Int,
    )

    Arguments

    quantized_weights

    Pesos quantizados (INT4)

    awq_scales

    Scales AWQ por canal

    quant_scales

    Scales de quantização por grupo

    zero_points

    Zero-points (se assimétrico)

    shape

    Shape original

    memory_bytes

    Memória em bytes

Values

pub fn apply_activation_transform(
  activations: List(Float),
  scales: AWQScales,
) -> List(Float)

Aplica transformação inversa às ativações X’ = X * diag(1/s) Isso compensa o scaling dos pesos

pub fn apply_weight_transform(
  weights: List(List(Float)),
  scales: AWQScales,
) -> List(List(Float))

Aplica transformação equivalente aos pesos W’ = W * diag(s) Isso escala canais salientes PARA CIMA

pub fn benchmark_awq() -> Nil
pub fn collect_activation_stats(
  activations_batch: List(List(Float)),
) -> List(Float)

Coleta estatísticas de ativação de um batch de calibração Retorna média absoluta por canal

pub fn compute_awq_scales(
  activation_stats: List(Float),
  alpha: Float,
) -> AWQScales

Computa scales AWQ baseado nas estatísticas de ativação scale[i] = activation_stat[i] ^ alpha

pub fn default_config() -> AWQConfig

Configuração padrão AWQ

pub fn dequantize_awq(awq: AWQTensor) -> tensor.Tensor

Dequantiza tensor AWQ

pub fn identify_salient_channels(
  activation_stats: List(Float),
  top_percent: Float,
) -> List(Int)

Identifica canais salientes (top-k por ativação)

pub fn main() -> Nil
pub fn quantize_awq(
  weights: tensor.Tensor,
  calibration_data: List(List(Float)),
  config: AWQConfig,
) -> AWQTensor

Quantiza pesos usando AWQ (pipeline completo)

Search Document