viva_tensor/awq
AWQ (Activation-aware Weight Quantization)
MLSys 2024 BEST PAPER AWARD! https://arxiv.org/abs/2306.00978
INSIGHT PRINCIPAL: Apenas ~1% dos pesos são “salientes” - identificados pela magnitude das ATIVAÇÕES, não dos pesos!
ALGORITMO:
- Coletar estatísticas de ativação (calibration)
- Identificar canais salientes (alta ativação média)
- Escalar canais salientes PARA CIMA antes de quantizar
- Escalar ativações de entrada PARA BAIXO (matematicamente equivalente)
RESULTADO: Mesma compressão, MUITO menos erro!
Implementação: MIT-HAN Lab + AutoAWQ
Types
Configuração AWQ
pub type AWQConfig {
AWQConfig(
bits: Int,
group_size: Int,
alpha: Float,
zero_point: Bool,
)
}
Constructors
-
AWQConfig( bits: Int, group_size: Int, alpha: Float, zero_point: Bool, )Arguments
- bits
-
Bits de quantização (4 é padrão)
- group_size
-
Tamanho do grupo para scales
- alpha
-
Expoente alpha para scaling (0.5 é típico)
- zero_point
-
Usar zero-point (assimétrico)
Scales AWQ computados
pub type AWQScales {
AWQScales(
weight_scales: List(Float),
activation_stats: List(Float),
alpha: Float,
)
}
Constructors
-
AWQScales( weight_scales: List(Float), activation_stats: List(Float), alpha: Float, )Arguments
- weight_scales
-
Scales por canal (multiplicador de pesos)
- activation_stats
-
Estatísticas de ativação usadas
- alpha
-
Alpha usado
Tensor quantizado com AWQ
pub type AWQTensor {
AWQTensor(
quantized_weights: List(Int),
awq_scales: AWQScales,
quant_scales: List(Float),
zero_points: List(Int),
shape: List(Int),
memory_bytes: Int,
)
}
Constructors
-
AWQTensor( quantized_weights: List(Int), awq_scales: AWQScales, quant_scales: List(Float), zero_points: List(Int), shape: List(Int), memory_bytes: Int, )Arguments
- quantized_weights
-
Pesos quantizados (INT4)
- awq_scales
-
Scales AWQ por canal
- quant_scales
-
Scales de quantização por grupo
- zero_points
-
Zero-points (se assimétrico)
- shape
-
Shape original
- memory_bytes
-
Memória em bytes
Values
pub fn apply_activation_transform(
activations: List(Float),
scales: AWQScales,
) -> List(Float)
Aplica transformação inversa às ativações X’ = X * diag(1/s) Isso compensa o scaling dos pesos
pub fn apply_weight_transform(
weights: List(List(Float)),
scales: AWQScales,
) -> List(List(Float))
Aplica transformação equivalente aos pesos W’ = W * diag(s) Isso escala canais salientes PARA CIMA
pub fn benchmark_awq() -> Nil
pub fn collect_activation_stats(
activations_batch: List(List(Float)),
) -> List(Float)
Coleta estatísticas de ativação de um batch de calibração Retorna média absoluta por canal
pub fn compute_awq_scales(
activation_stats: List(Float),
alpha: Float,
) -> AWQScales
Computa scales AWQ baseado nas estatísticas de ativação scale[i] = activation_stat[i] ^ alpha
pub fn identify_salient_channels(
activation_stats: List(Float),
top_percent: Float,
) -> List(Int)
Identifica canais salientes (top-k por ativação)
pub fn quantize_awq(
weights: tensor.Tensor,
calibration_data: List(List(Float)),
config: AWQConfig,
) -> AWQTensor
Quantiza pesos usando AWQ (pipeline completo)