LLM ModelHandle API
viva_tensor exposes a public ModelHandle API for Llama-family HuggingFace
models stored as SafeTensors. It packages the production TinyLlama decode path
into two calls: load the model once, then generate from the cached handle.
The API is designed for local BF16 HF checkpoints with the standard Llama tensor names:
- model.embed_tokens.weight
- model.layers.N.self_attn.{q,k,v,o}_proj.weight
- model.layers.N.mlp.{gate,up,down}_proj.weight
- model.layers.N.{input,post_attention}_layernorm.weight
- model.norm.weight
- lm_head.weight
If config.json is present next to the SafeTensors file, viva_tensor reads
the hidden size, layer count, head count, KV head count, RMSNorm epsilon, RoPE
theta, intermediate size, and vocab size from it. Otherwise it infers what it
can from tensor shapes and uses the TinyLlama-compatible defaults.
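The precedence described above (config.json first, then tensor shapes, then defaults) can be pictured as a layered map merge. This is only a sketch under stated assumptions: the module name `config_sketch` is hypothetical, the keys are the standard HF Llama `config.json` field names, and the default values are taken from TinyLlama-1.1B's published config rather than from viva_tensor's source, which this document does not show. Decoding `config.json` into a map is left to whatever JSON library is available.

```erlang
-module(config_sketch).
-export([resolve/2]).

%% Assumed fallback values: TinyLlama-1.1B's published config.json.
%% viva_tensor's real defaults are not shown in this document.
defaults() ->
    #{<<"hidden_size">> => 2048,
      <<"num_hidden_layers">> => 22,
      <<"num_attention_heads">> => 32,
      <<"num_key_value_heads">> => 4,
      <<"rms_norm_eps">> => 1.0e-5,
      <<"rope_theta">> => 10000.0,
      <<"intermediate_size">> => 5632,
      <<"vocab_size">> => 32000}.

%% Config: decoded config.json as a map (an empty map when the file is missing).
%% The second argument is the shape of model.embed_tokens.weight,
%% which for HF Llama checkpoints is [VocabSize, HiddenSize].
resolve(Config, [Vocab, Hidden]) ->
    FromShapes = #{<<"vocab_size">> => Vocab, <<"hidden_size">> => Hidden},
    %% maps:merge/2 lets the second map win, so the order of precedence is
    %% defaults < shape inference < config.json.
    maps:merge(maps:merge(defaults(), FromShapes), Config).
```

Calling `config_sketch:resolve(#{}, [32000, 2048])` models the case where no `config.json` exists: vocab and hidden size come from the embedding shape, and everything else falls back to the assumed defaults.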
Gleam
import viva_tensor as t

pub fn main() {
  let assert Ok(model) = t.load_model("tmp/tinyllama/model.safetensors")
  let opts =
    t.GenerateOpts(
      max_new_tokens: 50,
      temperature: 0.0,
      top_k: t.TopKInfinity,
      top_p: 1.0,
      seed: 42,
      stop_on_eos: True,
    )
  let assert Ok(result) = t.generate(model, "Hello", opts)
  result.text
}
temperature: 0.0 uses the fused argmax decode-step NIF for byte-identical
reproducibility. temperature > 0.0 uses fused top-k logits, then applies
temperature scaling, top-k and top-p filtering, and seeded multinomial
sampling on the host.
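To make the host-side sampling steps concrete, here is a minimal sketch of temperature scaling, top-k, top-p, and a seeded multinomial draw over a small candidate list. It is not viva_tensor's implementation: the module `sampling_sketch` and its function names are hypothetical, the candidate list merely stands in for the fused top-k logits, and the 256-candidate cap is not modelled.

```erlang
-module(sampling_sketch).
-export([sample/2]).

%% Candidates: [{TokenId, Logit}], standing in for the fused top-k logits.
%% Opts: #{temperature, top_k, top_p, seed}; temperature must be > 0.0.
sample(Candidates, #{temperature := T, top_k := K, top_p := P, seed := Seed}) ->
    %% Highest logit first, then keep at most K candidates.
    Sorted = lists:sort(fun({_, A}, {_, B}) -> A >= B end, Candidates),
    Kept = case K of
               infinity -> Sorted;
               _ -> lists:sublist(Sorted, K)
           end,
    %% Temperature-scaled softmax, shifted by the max logit for stability.
    [{_, Max} | _] = Kept,
    Weights = [{Id, math:exp((L - Max) / T)} || {Id, L} <- Kept],
    Sum = lists:sum([W || {_, W} <- Weights]),
    Probs = [{Id, W / Sum} || {Id, W} <- Weights],
    %% Nucleus: shortest prefix whose cumulative probability reaches P.
    Nucleus = take_nucleus(Probs, P, 0.0),
    %% Seeded draw: the same seed always yields the same uniform number.
    {U, _} = rand:uniform_s(rand:seed_s(exsss, Seed)),
    Mass = lists:sum([Pr || {_, Pr} <- Nucleus]),
    pick(Nucleus, U * Mass).

take_nucleus([], _P, _Acc) -> [];
take_nucleus([{_, Pr} = H | Rest], P, Acc) ->
    case Acc + Pr >= P of
        true -> [H];
        false -> [H | take_nucleus(Rest, P, Acc + Pr)]
    end.

%% Walk the candidates, subtracting each probability from the draw.
pick([{Id, _}], _U) -> Id;
pick([{Id, Pr} | Rest], U) ->
    case U < Pr of
        true -> Id;
        false -> pick(Rest, U - Pr)
    end.
```

Because the random state is built from the seed on every call, repeating a call with identical candidates and options returns the same token, which is the property the `seed` option relies on.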
For reproducible sampling:
let opts =
  t.GenerateOpts(
    max_new_tokens: 20,
    temperature: 0.8,
    top_k: t.TopK(40),
    top_p: 0.95,
    seed: 42,
    stop_on_eos: True,
  )
let assert Ok(result) = t.generate(model, "Hello", opts)
Erlang
{ok, Model} = viva_tensor_llm:load(
    <<"tmp/tinyllama/model.safetensors">>,
    #{block_size => 16}
),
{ok, Result} = viva_tensor_llm:generate(
    Model,
    <<"Hello">>,
    #{max_new_tokens => 50, temperature => 0.0}
),
#{tokens := Tokens,
  text := Text,
  ms_per_token := MsPerToken,
  total_tokens := TotalTokens} = Result.
Load Options
viva_tensor_llm:load/2 accepts:
| Option | Default | Notes |
|---|---|---|
| num_layers | detected from SafeTensors / config.json | Number of decoder blocks to load. |
| block_size | 16 | FP8 blocked prepack size used by the decode-step path. |
| tokenizer_path | <model>_tokenizer.json, then sibling tokenizer.json fallback | HF tokenizer JSON. |
Generation Options
viva_tensor_llm:generate/3 accepts:
| Option | Default | Notes |
|---|---|---|
| max_new_tokens | 50 | Maximum generated tokens. |
| temperature | 0.0 | 0.0 keeps the argmax path and absolute reproducibility; values above zero enable sampling. |
| top_k | infinity | Sampling candidate cap. infinity uses up to 256 fused top-k logits; explicit values are capped at 256. |
| top_p | 1.0 | Nucleus sampling probability applied over the fused candidate set. |
| seed | 42 | Deterministic seed; the same prompt, model, and options reproduce the same sampled tokens. |
| stop_on_eos | true | Stop after emitting EOS. |
Cached vs Per Call
The ModelHandle caches:
- tokenizer state
- BF16 or F16 embedding table as a native resource
- all layer weights prepacked with blocked FP8 scales
- fused QKV and gate-up packed weights
- final RMSNorm bytes
- packed lm_head
- RoPE frequency bytes
- model metadata from config.json and tensor shapes
Each generate call allocates fresh KV caches before prefill. KV cache
resources are mutable during decode, so they are intentionally per call to keep
one ModelHandle reusable across prompts.
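As an illustration of that reuse, the sketch below passes one already-loaded handle to several generate calls in a row. The module `handle_reuse_sketch` and its function `run_prompts/3` are hypothetical; the generate function is injected as a fun so the pattern can be read and exercised without the library, and in real use it would be `fun viva_tensor_llm:generate/3` with a handle returned by `viva_tensor_llm:load/2`.

```erlang
-module(handle_reuse_sketch).
-export([run_prompts/3]).

%% Generate: a fun/3 with the shape of viva_tensor_llm:generate/3.
%% Model: an already-loaded handle; it is passed unchanged to every call,
%% so the cached weights and tokenizer are shared across prompts.
run_prompts(Generate, Model, Prompts) ->
    Opts = #{max_new_tokens => 50, temperature => 0.0},
    lists:map(
      fun(Prompt) ->
              %% Each call is independent: per the text above, KV caches
              %% are allocated fresh inside generate.
              {ok, #{text := Text}} = Generate(Model, Prompt, Opts),
              {Prompt, Text}
      end,
      Prompts).
```

A call such as `handle_reuse_sketch:run_prompts(fun viva_tensor_llm:generate/3, Model, [<<"Hello">>, <<"Hi">>])` would then pay the load cost once and only the per-call KV cache allocation for each prompt.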
Tested models
| Model | Status | Decode speed | Notes |
|---|---|---|---|
| TinyLlama-1.1B-Chat-v1.0 | validated | 2.31 ms/token | head_dim=64, GQA fast path, byte-level BPE tokenizer. |
| Llama-3.2-1B-Instruct | validated | 2.47 ms/token | sharded SafeTensors, tied embeddings / lm_head, Llama-3 tokenizer path. |
| NousResearch/Llama-2-7b-chat-hf | validated | 113.18 ms/token | sharded F16 SafeTensors, head_dim=128, no GQA; exercises the dynamic CUDA fallback path. |
The same public API drives all three models:
let assert Ok(model) = t.load_model("tmp/llama32_1b/model-00001-of-00002.safetensors")
let opts = t.default_generate_opts()
let assert Ok(result) = t.generate(model, "Hello", opts)
Performance
On an RTX 4090, the current public handle API has been validated at
2.31 ms/token for TinyLlama-1.1B, 2.47 ms/token for
Llama-3.2-1B-Instruct, and 113.18 ms/token for
NousResearch/Llama-2-7b-chat-hf. The Llama-2-7B run is functional and coherent,
but much slower because it exercises the current head_dim=128 dynamic path.
The best TinyLlama FP8 W8A16 decode run reaches 448 tok/s, ahead of the local
Ollama baseline at 352 tok/s.
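The figures above mix two units, ms/token and tok/s, which are reciprocals of each other scaled by 1000. The small helper below (a hypothetical module, `throughput_sketch`, not part of viva_tensor) makes the conversion explicit. With it, 2.31 ms/token is about 433 tok/s, 2.47 ms/token about 405 tok/s, and 113.18 ms/token about 8.8 tok/s, while the 448 tok/s run corresponds to roughly 2.23 ms/token and is about 1.27x the 352 tok/s baseline.

```erlang
-module(throughput_sketch).
-export([tok_per_s/1, ms_per_token/1, speedup/2]).

%% Convert a per-token latency in milliseconds to tokens per second.
tok_per_s(MsPerToken) -> 1000.0 / MsPerToken.

%% Convert tokens per second back to a per-token latency in milliseconds.
ms_per_token(TokPerS) -> 1000.0 / TokPerS.

%% Ratio of one throughput to a baseline throughput.
speedup(TokPerS, BaselineTokPerS) -> TokPerS / BaselineTokPerS.
```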
Generation still calls nt_forward_decode_step/8 once per decoded token.
Prefill is also token-by-token today; a batched prefill path is future work.
Limitations
- Phi-2 is not a drop-in target. Its architecture and tensor naming diverge from the Llama-family loader contract used by ModelHandle.
- Llama-2-7B uses the slow dynamic attention path today. NousResearch/Llama-2-7b-chat-hf validates sharded F16 loading and head_dim=128 correctness, but decode throughput is not optimized yet.
- No batched prefill path yet. The decode kernel is optimized for batch=1; batched prompt processing is still expressed as repeated decode-step calls.
- True FP8xFP8 is not used for LLM decode. The numerically validated path is W8A16 with blocked FP8 weights. Quantizing the single-token activation would save only a few KB per token while risking the argmax/EOS behavior already validated on TinyLlama and Llama-3.2-1B.
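The "few KB" estimate in the last item can be checked with simple arithmetic, assuming the activation in question is a single hidden-state vector stored at 16 bits per element. The module `activation_sketch` below is hypothetical, and the hidden sizes used to read the result (2048 for TinyLlama-1.1B, 4096 for Llama-2-7B) come from those models' published configs, not from this document.

```erlang
-module(activation_sketch).
-export([saved_bytes/1]).

%% Bytes saved by storing one activation vector as FP8 (1 byte per element)
%% instead of BF16 or F16 (2 bytes per element).
saved_bytes(NumElements) ->
    NumElements * 2 - NumElements * 1.
```

Under these assumptions, `activation_sketch:saved_bytes(2048)` is 2048 bytes (2 KiB) for a TinyLlama-sized vector and `activation_sketch:saved_bytes(4096)` is 4 KiB for a Llama-2-7B-sized one, which is small next to the weight traffic of each decode step.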