From 93 to 1222 tok/s: Serving IBM Granite 4.1 on Dual RTX 3090s

A full-stack tour of three rounds on the same dual-3090 host that ran our Qwen3.6 work — first the 8B dense (1222 tok/s aggregate, 2624 peak), then the 30B dense (216 tok/s with a quantization-format gotcha that breaks one INT4 variant entirely), and finally a trained-from-scratch EAGLE-3 speculative-decoding head that buys +13.5% single-stream throughput and ships in 30 minutes of training.

TL;DR. On dual RTX 3090s we built three production stacks for IBM's Granite 4.1 family:
Round 1 — Granite-4.1-8B (dense AWQ-INT4): 105 tok/s single-stream, 1222 tok/s aggregate at C=64 via 2 replicas behind nginx (peak 2624 tok/s burst). At C=4 it lands 307 tok/s with 11 ms TPOT, beating both Qwen3.6 stacks on identical hardware.
Round 2 — Granite-4.1-30B (dense INT4): 216 tok/s aggregate at C=8 via TP=2. The cyankiwi asymmetric INT4 group_size=32 quant produces gibberish on Ampere; the drawais symmetric INT4 group_size=128 quant runs cleanly. TP=2 wins over the 2-replica LB pattern at this size — an inversion of the Qwen3.6-27B finding driven entirely by VRAM headroom for max-num-seqs.
Round 3 — EAGLE-3 head trained on Granite-4.1-3B: +13.5% single-stream tok/s and -13% TPOT at C=1 from a head we trained ourselves in 30 minutes (Granite ships no MTP head, unlike Qwen3.6). The head fits in 380 MB and slots into vLLM with two small patches.

The hardware

GPUs: 2× NVIDIA RTX 3090 (24 GB each, 48 GB total, Ampere SM 8.6)
Host: Ubuntu 22.04, kernel 6.8.0 (HWE)
Driver: 580.159.03 (CUDA 13.0)
Image: vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08 (vLLM 0.19.2)
Models: Granite-4.1-{3B, 8B, 30B} dense, pure GQA transformers (no Mamba/SSM/hybrid attention)

Granite 4.1's language models are interesting for the same reason Qwen3.6-27B was uninteresting in our last write-up: they're plain dense transformers. No DeltaNet, no Mamba-2, no hybrid attention layers, no boundary-protection guards. The entire Genesis-patches detour we needed for Qwen3.6 (TurboQuant rejecting hybrid models, custom KV cache code paths) doesn't exist here. AWQ-INT4 just works on every layer.

The flip side is that Granite has no equivalent of Qwen's MTP head: it ships no built-in speculative-decoding draft. So Round 3 picks up where Qwen's "MTP n=5" optimization left off, by training the equivalent ourselves.

Architecture — the final 8B stack

client: Agent / API, POST /v1/...
  → load balancer: nginx :8400 (least_conn)
      → replica #1: vLLM :8600 on GPU 0 (RTX 3090)
      → replica #2: vLLM :8601 on GPU 1 (RTX 3090)
model (each replica): Granite-4.1-8B, cyankiwi AWQ-INT4, 5.4 GB, max-num-seqs=32

Two independent vLLM replicas, no NCCL between them, behind a least-connections nginx proxy. Same shape as our Qwen3.6-27B Round 1 — but the model is small enough that each replica absorbs 32 concurrent streams before saturating its 3090, so the whole stack runs C=64 productively.
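nginx's least_conn sends each new request to the upstream with the fewest active connections and, when several are tied, takes them in (weighted) round-robin order. Below is a toy Python model of that choice, ignoring weights, max_fails and health state; the replica labels are illustrative, not anything nginx exposes.

```python
from dataclasses import dataclass, field


@dataclass
class Upstream:
    """One replica behind the proxy; `name` is just a label here."""
    name: str
    active: int = 0  # requests currently in flight to this replica


@dataclass
class LeastConnPool:
    """Toy least-connections picker (no weights, max_fails, or health checks)."""
    servers: list[Upstream]
    _next: int = field(default=0)  # rotating scan offset used to break ties

    def acquire(self) -> Upstream:
        n = len(self.servers)
        # Scan from a rotating offset so equally loaded servers take turns.
        order = [self.servers[(self._next + i) % n] for i in range(n)]
        self._next = (self._next + 1) % n
        best = min(order, key=lambda s: s.active)  # first minimum in scan order
        best.active += 1
        return best

    def release(self, server: Upstream) -> None:
        server.active -= 1  # call when the response (including an SSE stream) ends


pool = LeastConnPool([Upstream("vllm-8600"), Upstream("vllm-8601")])
first_four = [pool.acquire().name for _ in range(4)]
```

Least-connections suits long-lived streaming generations better than plain round-robin: a replica busy with a few long completions stops receiving new work until its in-flight count drops back to its peer's.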

Round 1 — Granite-4.1-8B (dense, AWQ-INT4)

Stage 1 → 2: pick the right quant on Ampere

Result: FP8 = gibberish, AWQ-INT4 = correct.

The headline IBM checkpoint is ibm-granite/granite-4.1-8b-fp8 — official, signed, 9 GB on disk. It boots cleanly on a 3090. It produces gibberish.

# curl ... -d '{"messages":[{"role":"user","content":"Capital of France?"}]}'
{"choices":[{"message":{"content":"                    "}}]}

vLLM logs explain why:

WARNING [marlin_utils_fp8.py:97] Your GPU does not have native support for FP8
computation but FP8 quantization is being used. Weight-only FP8 compression
will be used leveraging the Marlin kernel. This may degrade performance for
compute-heavy workloads.

Ada (SM 8.9), Hopper, and later GPUs have native FP8 tensor cores. Ampere SM 8.6 doesn't, so vLLM falls back to a Marlin weight-only FP8 path that dequantizes to BF16 for the matmul. For the Granite-4.1-8B checkpoint specifically, that dequant emits numerically wrong scales and the model's hidden states drift to a degenerate fixed point, usually generating EOS-like tokens or whitespace.

The fix is to skip FP8 entirely on Ampere and use a community AWQ-INT4 checkpoint where the Marlin INT4 path is mature and well-tested:

docker run -d --name granite-8b --gpus '"device=0"' \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -p 8600:8000 --ipc=host --shm-size=16gb \
  -e VLLM_NO_USAGE_STATS=1 -e HF_HUB_OFFLINE=1 \
  -e VLLM_USE_FLASHINFER_SAMPLER=1 \
  vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08 \
  --model cyankiwi/granite-4.1-8b-AWQ-INT4 --served-model-name granite-8b \
  --tensor-parallel-size 1 --max-model-len 32768 \
  --gpu-memory-utilization 0.92 \
  --max-num-seqs 8 --max-num-batched-tokens 4096 \
  --enable-prefix-caching --enable-chunked-prefill

cyankiwi/granite-4.1-8b-AWQ-INT4 is 5.4 GB on disk (vs FP8's 9 GB), runs the Marlin INT4 forward kernel that's been the workhorse for Qwen, Llama, and Mistral on Ampere for over a year, and produces correct outputs on the first request:

> Capital of France?
< Paris

A second small gotcha: --kv-cache-dtype fp8_e5m2 is rejected by vLLM 0.19 nightly with the message "fp8_e5m2 kv-cache is not supported with fp8 checkpoints". The guard is over-broad — it fires on any compressed-tensors checkpoint, including INT4. Drop the flag. The auto (BF16) KV cache is fine; the 8B model is small enough that KV memory isn't the bottleneck.

Stage 3: scaling max-num-seqs on a single 3090

Result: 105 tok/s at C=1 → 1052 tok/s at C=32.

The 8B at INT4 is 5.4 GB of weights — leaves ~18 GB on a 3090 for KV cache and activations. That headroom is enormous for a model this size, and max-num-seqs was the lever to use it.

| C | tok/s sustained | tok/s peak | TPOT |
|---|---|---|---|
| 1 | 105 | 119 | 8.5 ms |
| 4 | 374 | 420 | 10.3 ms |
| 8 | 676 | 736 | 11.8 ms |
| 16 | 812 | 1168 | 17.0 ms |
| 32 | 1052 | 1504 | 28.0 ms |

With --max-num-seqs 32 --max-num-batched-tokens 4096, a single 3090 sustains over 1000 tok/s aggregate at C=32. TPOT climbs from 8.5 ms (C=1) to 28 ms (C=32), but per-stream latency stays interactive throughout. For agentic systems the C=4–8 sweet spot lands at 374–676 tok/s with TPOT under 12 ms.
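A quick sanity check on that table: at steady state each of C streams emits one token per TPOT, so aggregate decode throughput can't exceed roughly C / TPOT. The sketch compares that ceiling with the measured sustained numbers; the gap is presumably prefill of the 1024-token prompts and request ramp-up, which this simple bound ignores.

```python
# Single-3090 sweep from the table above: (concurrency, sustained tok/s, TPOT in ms).
SWEEP = [(1, 105, 8.5), (4, 374, 10.3), (8, 676, 11.8), (16, 812, 17.0), (32, 1052, 28.0)]


def decode_ceiling(concurrency: int, tpot_ms: float) -> float:
    """Aggregate tok/s if every stream emitted one token per TPOT, with no prefill or idle time."""
    return concurrency * 1000.0 / tpot_ms


for c, sustained, tpot in SWEEP:
    ceiling = decode_ceiling(c, tpot)
    print(f"C={c:>2}: ceiling {ceiling:7.1f} tok/s, measured {sustained:>5} tok/s "
          f"({sustained / ceiling:.0%} of ceiling)")
```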

Stage 4: 2-replica + nginx LB (1222 tok/s aggregate)

Result: 1052 tok/s (1 GPU) → 1222 tok/s sustained, 2624 tok/s peak (2 GPUs).

Next we went to two replicas, one per GPU, behind a least-connections nginx LB, and swept aggregate throughput through the LB on port 8400 with the same prompt shape (1024 in / 1024 out):

8B-AWQ-INT4, 2 replicas + nginx LB: aggregate output tok/s (1024-token random input, 1024-token output), sustained vs peak burst

| C | sustained tok/s | peak burst tok/s |
|---|---|---|
| 1 | 105 | 121 |
| 2 | 186 | 228 |
| 4 | 307 | 436 |
| 8 | 477 | 712 |
| 16 | 757 | 1272 |
| 32 | 984 | 1974 |
| 64 | 1222 | 2624 |

The 8B fits comfortably on each card with max-num-seqs=32, so the stack's concurrency capacity scales linearly with replicas up to C=64.

For comparison with our Qwen3.6 stacks on the same hardware:

Granite-4.1-8B at C=4 hits 307 tok/s with TPOT 11 ms — and keeps scaling. At C=64 it lands 1222 tok/s sustained, 2624 tok/s peak burst. The Qwen stacks were already saturated at C=4–8; the Granite-8B has 4× more concurrency headroom because its 5.4 GB weight footprint leaves room for max-num-seqs=32 per replica without KV pressure.

The architectural lesson: when the model is small enough to fit comfortably with a generous KV budget, the 2-replica LB pattern doesn't just give you fault isolation — it gives you a much higher concurrency ceiling than any single-instance topology. The weight-bytes-vs-VRAM ratio is what you optimize.

Round 2 — Granite-4.1-30B (dense, INT4)

The 8B is the workhorse, but the 30B is what you reach for when you need quality. Two AWQ-INT4 community checkpoints exist:

| Checkpoint | Size | Quant scheme |
|---|---|---|
| cyankiwi/granite-4.1-30b-AWQ-INT4 | 19.9 GB | asymmetric INT4, group_size=32, sharded |
| drawais/Granite-4.1-30B-AWQ-INT4 | 15.5 GB | symmetric INT4, group_size=128, single safetensors |

Both are valid INT4 quantizations of the same model weights. Only one of them gives correct output on Ampere.

The quant-format finding

The cyankiwi checkpoint loads cleanly, runs at TP=2, and produces:

> Capital of Japan in one word?
< noun noun nounLorem Ipsum Lorem ipsum LoremLorem Lorem

The drawais checkpoint, loaded with the same flags, produces:

> Capital of Japan in one word?
< Tokyo

The fault isn't quant bit-width — both are INT4. It's quant parameters. The Marlin INT4 kernel on Ampere has different code paths for symmetric vs asymmetric, and for different group sizes. Asymmetric INT4 with group_size=32 hits a path that's numerically broken for granite-30B's specific layer-shape combination (64 layers, intermediate_size=32768). Symmetric INT4 with group_size=128 takes a different, correct path.
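To make the parameter difference concrete, here is a minimal pure-Python sketch of group-wise INT4 quantization in both styles: symmetric (one scale per group, codes centred on zero) and asymmetric (a scale plus an integer zero-point per group). It illustrates the arithmetic only; it does not model Marlin's packed layout or reproduce the bug, and the 16-bit scale and 4-bit zero-point sizes in the storage estimate are assumptions, since checkpoint formats vary.

```python
def quant_sym(group, bits=4):
    """Symmetric: one scale per group; integer codes centred on zero."""
    qmax = 2 ** (bits - 1) - 1                          # 7 for INT4
    scale = max(abs(x) for x in group) / qmax or 1.0    # guard all-zero groups
    codes = [max(-qmax - 1, min(qmax, round(x / scale))) for x in group]
    return codes, (scale,)


def dequant_sym(codes, params):
    (scale,) = params
    return [c * scale for c in codes]


def quant_asym(group, bits=4):
    """Asymmetric: a scale and an integer zero-point per group; codes in [0, 2**bits - 1]."""
    qmax = 2 ** bits - 1                                 # 15 for INT4
    lo, hi = min(min(group), 0.0), max(max(group), 0.0)  # keep 0.0 representable
    scale = (hi - lo) / qmax or 1.0
    zero = round(-lo / scale)                            # lands in [0, qmax]
    codes = [max(0, min(qmax, round(x / scale) + zero)) for x in group]
    return codes, (scale, zero)


def dequant_asym(codes, params):
    scale, zero = params
    return [(c - zero) * scale for c in codes]


def roundtrip(row, group_size, quant, dequant):
    """Quantize a weight row group by group and return the reconstructed floats."""
    out = []
    for start in range(0, len(row), group_size):
        codes, params = quant(row[start:start + group_size])
        out.extend(dequant(codes, params))
    return out


def metadata_bits_per_weight(group_size, scale_bits=16, zero_bits=0):
    """Extra storage per weight for per-group parameters (bit widths are assumptions)."""
    return (scale_bits + zero_bits) / group_size


row = [((i * 37) % 17 - 8) / 10 for i in range(256)]  # synthetic weights in [-0.8, 0.8]
err_sym = max(abs(a - b) for a, b in zip(row, roundtrip(row, 128, quant_sym, dequant_sym)))
err_asym = max(abs(a - b) for a, b in zip(row, roundtrip(row, 32, quant_asym, dequant_asym)))
```

Both round-trips reconstruct the row with similar error. What differs is what a kernel has to unpack per group (one scale every 128 weights vs a scale and a zero-point every 32), which is why the two formats take different code paths.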

The takeaway: never trust an INT4 checkpoint on Ampere without verifying its output. A clean boot and a successful profile_run mean nothing; the kernel can be wrong and still produce coherent-looking gibberish.
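A minimal smoke test along those lines: send a few prompts with known answers to the OpenAI-compatible endpoint and fail loudly if the expected word is missing. The probe prompts, expected answers, and the 16-token cap are illustrative choices, not a standard suite; it uses only the Python standard library.

```python
import json
import urllib.request

# Hypothetical probe set: prompt -> substring a healthy model's answer should contain.
PROBES = {
    "Capital of France? Answer in one word.": "paris",
    "Capital of Japan in one word?": "tokyo",
}


def looks_correct(reply: str, expected: str) -> bool:
    """Pass only if the expected answer appears; whitespace or word salad fails."""
    return expected.lower() in reply.lower()


def ask(base_url: str, model: str, prompt: str, timeout: float = 60.0) -> str:
    """One greedy chat completion against an OpenAI-compatible server."""
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 16,
        "temperature": 0,
    }).encode("utf-8")
    req = urllib.request.Request(f"{base_url}/v1/chat/completions", data=body,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.load(resp)["choices"][0]["message"]["content"] or ""


def smoke_test(base_url: str, model: str) -> list[tuple[str, str]]:
    """Return (prompt, reply) pairs that failed; an empty list means every probe passed."""
    failures = []
    for prompt, expected in PROBES.items():
        reply = ask(base_url, model, prompt)
        if not looks_correct(reply, expected):
            failures.append((prompt, reply))
    return failures
```

For example, smoke_test("http://localhost:8500", "granite-30b") against the Recipe 2 server should return an empty list; a checkpoint like the broken cyankiwi 30B shows up as failures.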

TP=2 vs LB vs single-card

The drawais 30B at INT4 is 15.5 GB. We measured all three topologies on the same model, same prompt shape, same vLLM:

| Topology | C=4 tok/s | C=8 tok/s | TPOT C=4 |
|---|---|---|---|
| Single 3090, max-num-seqs=4 | 164 |  | 24.4 ms |
| 2 replicas + LB, max-num-seqs=4 each | 84 | 168 | 34.9 ms |
| TP=2, max-num-seqs=8 | 138 | 216 | 29.0 ms |

For comparison we also ran the official ibm-granite/granite-4.1-30b-fp8 checkpoint (which, unlike the 8B FP8, does produce correct output on Ampere — different layer count, different code path):

| Topology | C=4 tok/s | C=8 tok/s |
|---|---|---|
| 30b-fp8 TP=2 | 88.5 | 148.6 |

At TP=2 the INT4 path is 45–56% faster than the FP8 checkpoint (138 vs 88.5 tok/s at C=4, 216 vs 148.6 tok/s at C=8), the same FP8-on-Ampere fallback penalty as the 8B. AWQ-INT4 is the clear winner on Ampere for the 30B as well.

Why TP=2 wins for 30B (and lost for 27B-dense in the Qwen blog)

In the Qwen3.6-27B work, TP=2 lost decisively to "2 replicas + nginx LB" because PCIe-Gen4 NCCL all-reduces dominated. For Granite-30B on the same hardware, TP=2 wins. Same hardware, same NCCL, different result. Why?

The answer is max-num-seqs headroom:

| Model | INT4 weight bytes | Per-card slack at TP=1 | Max per-replica seqs | Max TP=2 seqs |
|---|---|---|---|---|
| Qwen3.6-27B-dense | ~14 GB | ~9 GB for KV | 2 | 2 |
| Granite-4.1-30B-dense | ~15.5 GB | ~7 GB for KV | 4 | 8 |
| Granite-4.1-8B-dense | ~5.4 GB | ~17 GB for KV | 32 | 32 |

The 30B sits in an awkward valley: weights are big enough that a single-card replica caps at max-num-seqs=4 (no VRAM for a wider KV), but max-num-seqs=4 doesn't saturate compute. TP=2 splits weights to ~7.75 GB per card, frees enough VRAM to run max-num-seqs=8, and that doubled per-step batch outweighs the cross-GPU NCCL traffic.

For the smaller 8B (5.4 GB weights), neither replica is VRAM-constrained — both can run max-num-seqs=32 independently — so two replicas with no NCCL win. For the larger 30B, TP=2 wins because freeing batch headroom is worth more than the NCCL tax.

The right topology is a function of (weight_bytes, KV_per_seq, GPU_capacity). When weights >> 0.5× VRAM, TP=2 wins. When weights << 0.5× VRAM, 2 replicas + LB wins. The Qwen-blog "TP=2 always loses on PCIe-3090" finding was an artifact of measuring at one weight-byte regime.
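That rule of thumb can be made a little more concrete with a back-of-envelope model: count how many full-length sequences of KV fit per instance, and prefer NCCL-free replicas only when a single card can already batch enough to saturate compute. Every input below (KV per sequence, fixed overhead, the saturation point) is hypothetical and chosen for illustration; none of it is measured, and the model ignores activation memory and CUDA graphs.

```python
import math


def seq_budget(weights_gb: float, kv_gb_per_seq: float, tp: int,
               vram_gb: float = 24.0, util: float = 0.92, overhead_gb: float = 1.5) -> int:
    """How many full-length sequences of KV fit in one vLLM instance spanning `tp` GPUs.

    Weights and KV are both sharded across the TP group, so the instance pools
    tp * (usable VRAM - fixed per-card overhead) and pays for the weights once.
    """
    pooled_gb = tp * (vram_gb * util - overhead_gb) - weights_gb
    return max(0, math.floor(pooled_gb / kv_gb_per_seq))


def pick_topology(weights_gb: float, kv_gb_per_seq: float, saturating_seqs: int) -> str:
    """Prefer NCCL-free replicas when one card alone can batch enough to saturate compute."""
    per_replica = seq_budget(weights_gb, kv_gb_per_seq, tp=1)
    if per_replica >= saturating_seqs:
        return f"2 replicas + LB ({per_replica} seqs each)"
    return f"TP=2 ({seq_budget(weights_gb, kv_gb_per_seq, tp=2)} seqs)"


# Hypothetical inputs: KV-per-sequence and saturation points are illustrative, not measured.
small = pick_topology(weights_gb=5.4, kv_gb_per_seq=0.4, saturating_seqs=32)
large = pick_topology(weights_gb=15.5, kv_gb_per_seq=1.2, saturating_seqs=8)
```

The design choice worth noting: TP=2 always pools more KV room than two replicas (it pays for the weights once instead of twice), so the decision hinges on whether the extra batch is worth the NCCL tax, which is what the saturation threshold stands in for.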

Round 3 — Training an EAGLE-3 head from scratch

The Qwen3.6 blog's biggest single optimization was MTP n=5 speculative decoding: extra prediction heads ship with the model, vLLM uses them as a draft, and every accepted token is essentially free. From 64 → 100 tok/s C=1 on the dense 27B, just by turning a flag on.

Granite 4.1 ships no MTP head. Or EAGLE head. Or Medusa head. Just a tokenizer and a single causal LM. So we trained one ourselves.

Why EAGLE-3 (over draft-model / ngram / Medusa)

We tried draft-model speculation first — granite-4.1-3b BF16 drafting granite-4.1-8b AWQ-INT4 — but vLLM's draft loader inherits the target's compressed-tensors quant config and tries to load the BF16 3B weights as packed-INT4. Result: 158 weight names fail to initialize. To get this working you'd need a same-quant draft (an AWQ-INT4 3B doesn't exist publicly).

Plain ngram speculation gives a small win on the 8B:

| Config | C=1 tok/s | C=4 tok/s |
|---|---|---|
| 8B-AWQ baseline | 105 | 307 |
| 8B-AWQ + ngram(n=4) | 113 | 325 |

A free +6–8% (113 vs 105 at C=1, 325 vs 307 at C=4). Useful for code completion or structured output where prompt-prefix copies are common, but the upper bound is low because natural-language chat doesn't repeat token sequences much.
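Ngram speculation (often called prompt lookup) needs no trained weights: it looks for the most recent earlier occurrence of the last n tokens in the context and proposes whatever followed it as the draft. A minimal sketch of that lookup, with n=4 matching the config above; vLLM's own implementation differs in details (for instance it can try a range of n-gram lengths), so treat this as illustrative.

```python
def propose_ngram_draft(tokens: list[int], n: int = 4, k: int = 5) -> list[int]:
    """Prompt-lookup drafting: find the most recent earlier occurrence of the
    last `n` tokens and propose the (up to) `k` tokens that followed it."""
    if len(tokens) <= n:
        return []
    suffix = tokens[-n:]
    # Search right-to-left, skipping the suffix's own position at the very end.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start:start + n] == suffix:
            return tokens[start + n:start + n + k]
    return []  # no repeat in context: nothing to speculate, decode normally
```

The empty-draft case is why the win is small on chat: when the context never repeats the recent n-gram, there is simply nothing to propose.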

EAGLE-3 is the right shape for chat: a tiny (single-layer) auto-regressive head trained against the target's hidden states, then loaded as the draft. Acceptance rates of 0.4–0.6 are reported in the EAGLE-3 paper for fully-trained heads, corresponding to 1.5–2× single-stream speedup. Critically, EAGLE-3 heads can be trained in hours on consumer GPUs.

Adapting SpecForge for Granite 4.1

We used sgl-project/SpecForge (the modern EAGLE-3 trainer recommended by the EAGLE authors). Three small adaptations:

1. Add a granite chat template to specforge/data/template.py:

TEMPLATE_REGISTRY.register(
    name="granite",
    template=ChatTemplate(
        assistant_header="<|start_of_role|>assistant<|end_of_role|>",
        user_header="<|start_of_role|>user<|end_of_role|>",
        system_prompt="You are a helpful AI assistant.",
        end_of_turn_token="<|end_of_text|>",
    ),
)

2. Make all sglang imports lazy in three SpecForge modules. SpecForge's HF target backend doesn't actually need SGLang at runtime, but its imports are eager at module load. Wrapping them in try / except ImportError lets us run with just transformers + torch.

3. Write a granite-4.1-3b-eagle3.json head config — a llama-shaped EAGLE-3 head sized for the 3B target (hidden=2560, intermediate=8192, kv_heads=8, vocab=100352, draft_vocab=32000). On our 5k-sample ShareGPT subset, the top 32k tokens covered 98.36% of training tokens — basically free coverage from the partial-vocab trick.
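The 98.36% figure is a frequency-coverage statistic: count how often each token id appears in the tokenized training set and ask what fraction of all tokens the 32,000 most frequent ids account for. A minimal pure-Python sketch, assuming you already have tokenized samples as lists of ids (for example from the target's Hugging Face tokenizer); SpecForge's own vocab-mapping code is not reproduced here.

```python
from collections import Counter
from typing import Iterable


def topk_coverage(token_streams: Iterable[list[int]], k: int) -> tuple[list[int], float]:
    """Return the k most frequent token ids and the fraction of all tokens they cover."""
    counts = Counter()
    for ids in token_streams:
        counts.update(ids)
    total = sum(counts.values())
    top = counts.most_common(k)
    covered = sum(c for _, c in top)
    return [tok for tok, _ in top], (covered / total if total else 0.0)
```

With real data this would be called as topk_coverage(tokenized_samples, k=32000), where tokenized_samples is a hypothetical list of per-sample id lists; the returned ids are the candidate draft vocabulary.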

Training: 30 minutes, single GPU

SpecForge's online HF backend generates target hidden states on the fly per batch — no separate pre-extraction step. The 3B target fits comfortably in 24 GB BF16 alongside the EAGLE head and optimizer state.

The acceptance trajectory:

[Figure: EAGLE-3 head training, mean acceptance over the last 300 steps across a 5000-step single-GPU run (x axis: training step, y axis: acceptance rate). Labeled points: 0.10 at step 200, 0.24 at step 1500, 0.32 at step 5000; reference line at 0.50 marking the paper's "fully trained" regime.]
5000 steps over one epoch of the 5k-sample ShareGPT subset (1024 tokens per step) took 30 minutes of wall time on a single 3090. The EAGLE-3 paper's 0.4–0.6 regime comes from 50k samples × 10 epochs, which would be an overnight run.
| Step | Mean acc (last 300 steps) |
|---|---|
| 200 | 0.10 |
| 1500 | 0.24 |
| 5000 | 0.32 |

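Acceptance is the per-step probability that a single drafted token gets accepted by the target. Treating each drafted token as accepted independently with that probability, and verification as stopping at the first rejection, an acceptance of 0.32 with num_speculative_tokens=5 gives an expected 0.32 + 0.32² + ... + 0.32⁵ ≈ 0.47 extra tokens per verification, so each target forward pass emits about 1.47 tokens instead of 1. The draft-head forwards and the wider verify pass eat into that upper bound, leaving a predicted ~1.13–1.20× single-stream speedup, which is what we measure.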
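The expected-acceptance arithmetic above, written out. Under the same independence assumption the sum has the closed form a(1 − aⁿ)/(1 − a); these functions are a sketch of the estimate, not SpecForge's or vLLM's own accounting.

```python
def expected_extra_tokens(acceptance: float, num_speculative_tokens: int) -> float:
    """Expected accepted draft tokens per verify pass.

    Assumes each drafted token is accepted independently with probability
    `acceptance` and that verification stops at the first rejection, so the
    first k drafted tokens all survive with probability acceptance**k.
    """
    return sum(acceptance ** k for k in range(1, num_speculative_tokens + 1))


def tokens_per_target_pass(acceptance: float, num_speculative_tokens: int) -> float:
    """Upper bound on speedup: accepted drafts plus the one token the target always emits.

    Ignores the cost of running the draft head and of the wider verify pass.
    """
    return 1.0 + expected_extra_tokens(acceptance, num_speculative_tokens)


measured_speedup = 105.1 / 92.6  # headline C=1 numbers: EAGLE-3 (n=5) vs baseline
```

With a = 0.32 and n = 5 this gives about 0.47 extra tokens, so at most ~1.47 tokens per target pass, against the measured 105.1 / 92.6 ≈ 1.135×.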

Patching vLLM to accept a Granite EAGLE-3 head

vLLM's EAGLE-3 path has two whitelist gates that block Granite by default. Both are small patches:

1. Whitelist granite in the EAGLE-3 model-type list (vllm/config/speculative.py):

aux_hidden_states_supported = [
    "llama", "qwen", "minicpm", "gpt_oss", "hunyuan_vl", "hunyuan_v1_dense",
    "afmoe", "nemotron_h", "deepseek_v2", "deepseek_v3", "kimi_k2", "kimi_k25",
    "minimax_m2", "gemma4",
    "granite", # ← added
]

2. Implement the SupportsEagle3 interface on GraniteForCausalLM:

from vllm.model_executor.models.interfaces import EagleModelMixin, SupportsEagle3

class GraniteModel(nn.Module, EagleModelMixin):  # ← mix in
    def forward(self, ...):
        aux_hidden_states: list[torch.Tensor] = []
        aux_hidden_states = self._maybe_add_hidden_state(
            aux_hidden_states, 0, hidden_states, None
        )
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(positions, hidden_states)
            aux_hidden_states = self._maybe_add_hidden_state(
                aux_hidden_states, idx + 1, hidden_states, None
            )
        hidden_states = self.norm(hidden_states)
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    pass  # SupportsEagle3 mix-in handles set_aux_hidden_state_layers

Total: ~30 lines of monkey-patch in a derived Docker image; both patches are candidates for upstream PRs.

Sweeping num_speculative_tokens

Single replica, granite-3B BF16, 5000-step head, C=1, 1024 in / 1024 out:

| n | tok/s C=1 | TPOT |
|---|---|---|
| 2 | 91.0 | 10.7 ms |
| 3 | 96.9 | 10.0 ms |
| 4 | 95.3 | 10.2 ms |
| 5 | 104.9 | 9.25 ms |

n=5 is the best depth in this sweep, the same depth at which the Qwen3.6 blog's MTP saturated. Beyond n=5 the draft acceptance falls faster than the speculation gain. Two unrelated draft mechanisms (one trained, one shipped), same optimal depth.

Headline number

| Config | tok/s C=1 | TPOT C=1 |
|---|---|---|
| Granite-3B BF16 baseline | 92.6 | 10.62 ms |
| Granite-3B + EAGLE-3 (n=5) | 105.1 | 9.25 ms |
| Δ | +13.5% | −13% |

A clean +13.5% from a head that took 30 minutes of single-GPU compute to produce.

Where EAGLE-3 stops winning

Baseline vs EAGLE-3 across concurrency (1024 in / 1024 out, n=5):

Granite-3B BF16, baseline vs EAGLE-3 (n=5) across concurrency (single replica on GPU 1, 1024 in / 1024 out, sustained tok/s):

| C | baseline | EAGLE-3 |
|---|---|---|
| 1 | 93 | 105 |
| 4 | 332 | 310 |
| 8 | 628 | 475 |
| 16 | 1048 | 604 |
| 32 | 1047 | 586 |

EAGLE-3 wins at C=1 and regresses sharply at C=8 and above; the verify-pass cost can't be amortized once batching saturates compute.

The crossover falls between C=1 and C=4. At C=1 EAGLE-3 wins (the latency-sensitive, single-call workload); by C=4 the baseline is already slightly ahead (332 vs 310 tok/s), and from there the gap widens, because the GPU is batch-saturated and the verify-pass plus draft-head forward overhead can't be amortized.

This is the canonical EAGLE-3 / MTP / Medusa trade-off, and it has a clean operational answer: deploy two replicas, turn EAGLE-3 ON on one and OFF on the other, route latency-critical single-call requests to the EAGLE replica. Same hardware, both wins, no compromise. vLLM lets you toggle --speculative-config per replica without changing the model path.
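A minimal sketch of that routing decision, assuming the proxy layer can tell latency-critical calls apart (for example by a request header you define) and tracks in-flight requests per replica. The URLs and the threshold of one in-flight request are hypothetical, picked from the C=1 vs C=4 crossover measured above.

```python
from dataclasses import dataclass


@dataclass
class Replica:
    url: str           # illustrative address of one vLLM replica
    in_flight: int = 0  # requests currently being served


def choose_replica(eagle: Replica, plain: Replica, latency_critical: bool,
                   eagle_max_in_flight: int = 1) -> Replica:
    """Send latency-critical calls to the EAGLE-3 replica only while it is nearly idle.

    EAGLE-3 wins at C=1 and already trails the baseline at C=4, so the speculative
    replica is kept at low concurrency; everything else, including overflow from
    latency-critical traffic, goes to the non-speculative replica.
    """
    if latency_critical and eagle.in_flight < eagle_max_in_flight:
        return eagle
    return plain
```

Capping the EAGLE replica's concurrency is the point of the design: without it, a burst of "latency-critical" calls would push that replica into the regime where speculation loses.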

Accuracy verification — does any of this hurt quality?

Throughput numbers are the easy part. The pushback we get whenever we publish numbers like these is "you're trading accuracy for speed". So we ran EleutherAI's lm-evaluation-harness (the suite behind the Open LLM Leaderboard and many published model-card numbers) against each of our optimized configs on two benchmarks: GSM8K (1319 grade-school math problems, 5-shot CoT, the metric most sensitive to quantization noise) and IFEval (541 instruction-following prompts, which catch "model technically responds but ignores constraints" failures). Each run hits the running vLLM endpoint via --model local-chat-completions, with vLLM's Prometheus /metrics snapshotted before and after, so we can read off the actual throughput during eval alongside the accuracy.
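The before/after /metrics arithmetic can be sketched with the standard library alone. This assumes the generated-token counter is named vllm:generation_tokens_total, the name recent vLLM releases use, but check it against your own /metrics output; the helper names are hypothetical.

```python
import re
import urllib.request

# One Prometheus text-format sample: name, optional {labels}, value (timestamp ignored).
# Assumes label values contain no '}' characters.
_SAMPLE = re.compile(r"^(?P<name>[A-Za-z_:][A-Za-z0-9_:]*)(?:\{[^}]*\})?\s+(?P<value>\S+)")


def scrape(base_url: str, timeout: float = 10.0) -> str:
    """Fetch vLLM's /metrics page as text."""
    with urllib.request.urlopen(f"{base_url}/metrics", timeout=timeout) as resp:
        return resp.read().decode("utf-8")


def counter_total(metrics_text: str, name: str) -> float:
    """Sum all samples of one metric across label sets (e.g., per-model labels)."""
    total = 0.0
    for line in metrics_text.splitlines():
        if not line or line.startswith("#"):
            continue  # skip blank lines and HELP/TYPE comments
        m = _SAMPLE.match(line)
        if m and m.group("name") == name:  # exact match, so *_created series are excluded
            total += float(m.group("value"))
    return total


def gen_tok_per_s(before: str, after: str, wall_s: float,
                  name: str = "vllm:generation_tokens_total") -> float:
    """Average generated tokens/s between two snapshots taken wall_s seconds apart."""
    return (counter_total(after, name) - counter_total(before, name)) / wall_s
```

In use: before = scrape("http://localhost:8600"), run the eval while timing it, then after = scrape(...) and gen_tok_per_s(before, after, wall_s).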

GSM8K (1319 problems, 5-shot, strict + flexible exact-match)

| Config | strict | flexible | wall | gen tok/s | TPOT |
|---|---|---|---|---|---|
| Granite-4.1-3B BF16 (baseline) | 84.91% | 86.73% | 582 s | 260.2 | 13.9 ms |
| Granite-4.1-3B + EAGLE-3 (n=5) | 85.60% | 86.81% | 621 s | 244.0 | 14.6 ms |
| Granite-4.1-8B AWQ-INT4 | 87.57% | 89.08% | 741 s | 247.4 | 14.0 ms |
| Granite-4.1-30B FP8 (official 8-bit reference) | 87.72% | 89.99% | 3405 s | 54.5 | 64.2 ms |
| Granite-4.1-30B AWQ-INT4 TP=2 (ours) | 87.57% | 89.69% | 2635 s | 69.7 | 48.9 ms |
| Granite-4.1-30B AWQ-INT4 cyankiwi (broken) | 0.00% | 0.00% | 3837 s | 87.8 | 36.1 ms |

IFEval (541 prompts, 0-shot, prompt-level + instruction-level)

| Config | prompt-strict | prompt-loose | inst-strict | inst-loose | wall | gen tok/s | TPOT |
|---|---|---|---|---|---|---|---|
| Granite-4.1-3B BF16 (baseline) | 78.93% | 80.22% | 85.49% | 86.81% | 383 s | 348.3 | 11.1 ms |
| Granite-4.1-3B + EAGLE-3 (n=5) | 78.37% | 80.04% | 85.13% | 86.57% | 403 s | 330.3 | 11.6 ms |
| Granite-4.1-8B AWQ-INT4 | 83.18% | 86.32% | 88.25% | 90.53% | 421 s | 441.6 | 8.8 ms |
| Granite-4.1-30B FP8 (official 8-bit reference) | 86.69% | 88.72% | 90.65% | 92.33% | 1597 s | 92.0 | 42.3 ms |
| Granite-4.1-30B AWQ-INT4 TP=2 (ours) | 87.62% | 89.83% | 91.37% | 93.05% | 1019 s | 141.8 | 27.4 ms |

Throughput-during-eval is meaningful here because the eval workload (chat-formatted multi-turn prompts at num_concurrent=4) is much closer to a real agentic load than the synthetic random-token benchmarks earlier in this post. The 8B AWQ holds 441 tok/s effective with sub-9 ms TPOT while serving the IFEval distribution; the 30B at TP=2 holds 142 tok/s effective at 27 ms TPOT.

The headline comparison: our optimization vs the official 8-bit reference

The most direct "what does the optimization actually cost in quality?" comparison is our AWQ-INT4 30B (drawais, sym INT4 g128) against the IBM-published granite-4.1-30b-fp8 checkpoint, on the same eval, same hardware:

| Metric | FP8 reference | AWQ-INT4 (ours) | delta | std error |
|---|---|---|---|---|
| GSM8K strict | 87.72% | 87.57% | −0.15 pt | ±0.90 pt |
| GSM8K flex | 89.99% | 89.69% | −0.30 pt | ±0.83 pt |
| IFEval prompt-strict | 86.69% | 87.62% | +0.93 pt | ±1.46 pt |
| IFEval prompt-loose | 88.72% | 89.83% | +1.11 pt | ±1.36 pt |
| Effective gen tok/s (GSM8K) | 54.5 | 69.7 | +28% faster |  |
| Effective gen tok/s (IFEval) | 92.0 | 141.8 | +54% faster |  |

Every accuracy delta is under 1.2 percentage points — well inside the test's noise band (the rightmost column shows the ±std-error bound for each benchmark; every delta lands inside it). AWQ-INT4 with our config is 28–54% faster than the official FP8 reference, at zero measurable quality cost. If you were worried that 4-bit quantization on Ampere was costing you measurable quality, the standard public eval suite says it isn't.

Why we ran the negative control

The bottom row of the GSM8K table (the cyankiwi 30B AWQ-INT4 at 0.00%) earns its keep by demonstrating something throughput benchmarks can't catch. The cyankiwi quant of the 30B (asymmetric INT4 group_size=32) loads cleanly under TP=2, vLLM reports it as healthy, and the engine generates 88 tok/s effective during the eval: by every throughput-side metric, this looks like a working model. The accuracy benchmark catches that the outputs are gibberish ('soeverhayhayhay...desk desksdeskdesk\uD emojis emojisemoji'). One run of GSM8K reduces the entire output to a single number, 0.00%, that no synthetic throughput test would ever produce. Run both. Always.

Lossless verification: EAGLE-3 vs baseline

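EAGLE-3 inherits losslessness from speculative sampling: each drafted token is accepted with probability min(1, p_target / p_draft), and on rejection a replacement is sampled from the normalized residual max(0, p_target − p_draft), so the output distribution is identical to running the target alone. Under greedy decoding this reduces to keeping drafted tokens only while they match the target's argmax. That's a mathematical guarantee, up to floating-point differences between batched verification and plain decoding. Empirical confirmation: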

| Metric | baseline (3B BF16) | + EAGLE-3 (n=5) | delta | std error |
|---|---|---|---|---|
| GSM8K strict | 84.91% | 85.60% | +0.68 pt | ±0.99 pt |
| GSM8K flex | 86.73% | 86.81% | +0.08 pt | ±0.93 pt |
| IFEval prompt-strict | 78.93% | 78.37% | −0.55 pt | ±1.75 pt |
| IFEval prompt-loose | 80.22% | 80.04% | −0.18 pt | ±1.71 pt |

All four deltas land under 1 percentage point of baseline (well inside the per-benchmark noise band shown in the right column) — exactly as the math says they should be. EAGLE-3 buys you single-stream throughput at zero quality cost. Stamp it.
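The accept/reject rule behind that losslessness guarantee fits in a few lines. This is a minimal sketch of standard speculative sampling for a single drafted token over a toy vocabulary, not vLLM's batched implementation; the function names are ours.

```python
import random


def verify_draft_token(draft_token, p_target, q_draft, rng):
    """One speculative-sampling step for a token that was drawn from q_draft.

    Returns (token, accepted). The emitted token is distributed exactly as p_target.
    """
    p, q = p_target[draft_token], q_draft[draft_token]
    if rng.random() < min(1.0, p / q):  # q > 0 because draft_token was sampled from q
        return draft_token, True
    # Rejected: resample from the part of p_target that q_draft under-covers.
    residual = [max(0.0, pt - qt) for pt, qt in zip(p_target, q_draft)]
    if sum(residual) <= 0.0:
        residual = list(p_target)  # only reachable via float round-off when p == q
    token = rng.choices(range(len(residual)), weights=residual)[0]
    return token, False


def sample_and_verify(p_target, q_draft, rng):
    """Draft one token from q_draft, then let the target accept or replace it."""
    draft_token = rng.choices(range(len(q_draft)), weights=q_draft)[0]
    return verify_draft_token(draft_token, p_target, q_draft, rng)
```

Sampling many tokens with p_target = [0.2, 0.5, 0.3] and q_draft = [0.6, 0.2, 0.2] yields empirical frequencies close to p_target even though the draft is badly calibrated; a worse draft only lowers the acceptance rate, never the output quality.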

Where the optimization story lands

Reading both tables together: the optimized configs all sit in the standard published Granite quality regime. The 8B AWQ-INT4 at 87.57% GSM8K strict / 83.18% IFEval prompt-strict and the 30B AWQ-INT4 TP=2 at 87.57% / 87.62% land within ~1 percentage point of the numbers IBM reports for its official BF16 / FP8 reference checkpoints on the same benchmarks, well inside the tests' noise band (GSM8K's standard error on n=1319 is ±0.9 pt, IFEval's on n=541 is ±1.5 pt). If you were worried that AWQ-INT4 quantization on Ampere was costing you measurable quality on the standard public eval suite, it isn't. You are giving up nothing.
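The ± columns in the two comparison tables match the plain binomial standard error sqrt(p(1 − p)/n) of the reference config's accuracy, rounded to two decimals. A short sketch, including a conservative standard error for the difference between two configs, which comes out about √2 wider when both have similar accuracy:

```python
import math


def stderr_pt(accuracy: float, n: int) -> float:
    """Binomial standard error of an accuracy measured on n items, in percentage points."""
    return 100.0 * math.sqrt(accuracy * (1.0 - accuracy) / n)


def delta_stderr_pt(acc_a: float, acc_b: float, n: int) -> float:
    """Standard error of acc_b - acc_a, treating the two runs as independent.

    Both configs answer the same items, so their errors are positively correlated
    and the true (paired) standard error is usually smaller: this is conservative.
    """
    return math.hypot(stderr_pt(acc_a, n), stderr_pt(acc_b, n))


gsm8k_ref = stderr_pt(0.8772, 1319)  # FP8 reference, GSM8K strict
ifeval_ref = stderr_pt(0.8669, 541)  # FP8 reference, IFEval prompt-strict
```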

Test configuration: All Granite 4.1 LMs have a native 131,072-token context window. The accuracy evals above used --max-model-len 8192, sufficient headroom for the GSM8K/IFEval prompt+answer shape. The throughput recipes earlier in this post run at higher caps (8B replicas at 32,768, 30B TP=2 at 16,384, 3B + EAGLE-3 at 16,384). Throughput-during-eval values in this section were measured with each 3090 power-limited to 280W (nvidia-smi -pl 280); the power limit doesn't affect accuracy, we expect the relative comparisons to hold at other power levels, and the synthetic benchmark numbers earlier in this post were measured at the default 350W.

Throughput across the context window

The throughput numbers earlier in this post were all at 1024-token prompts. The obvious follow-up: how does single-stream throughput hold up as you push the context window all the way toward Granite's 131,072-token native ceiling?

We swept input lengths from 512 up to as far as we could fit on the same dual-3090 host, one stream at a time (C=1), with exactly 128 output tokens per request (min_tokens=max_tokens=128, ignore_eos=true so the decode-TPS measurement is stable across configs). Each request hits /v1/completions directly (raw prompt, no chat template), so the streaming TTFT we measure is the engine's first decoded token. Two metrics per point: decode-only TPS and end-to-end TPS (which includes prefill time).

[Figure: Throughput vs input context, Granite 4.1 stack on dual RTX 3090. C=1, 128 output tokens (ignore_eos=true), /v1/completions raw; x axis: input context in tokens (log scale, vLLM-reported prompt_tokens); y axis: tokens/sec. Decode-only TPS solid, end-to-end TPS dashed. Lines: Granite 4.1 8B dense (AWQ-INT4), up to 90K; Granite 4.1 30B (AWQ-INT4, TP=2), up to 101K; Granite 4.1 3B BF16 + EAGLE-3 (n=5), up to 116K.]

Four observations:

  1. The 8B decode penalty is real but the absolute numbers stay usable. 117 → 31 tok/s over a 200× context expansion (455 → 92,736 input). At 30K input — the long-context regime most agentic workflows actually hit — the 8B is still doing 64 tok/s of decode. End-to-end TPS, dominated by prefill, sits at 9 tok/s at 30K; whatever you're doing at long context, the TTFT is the user-visible cost, not decode.
  2. The 30B loses ~75% of decode TPS over the sweep. 57 → 14 tok/s from 460 to 103,908 input. Like the 8B it is a dense GQA model, so every layer's KV cache grows with context; at 100K input that KV pressure is enough to halve decode again past the 60K point.
  3. The 3B + EAGLE-3 head collapses fastest. 80 → 9 tok/s from 459 to 118,809 input. The EAGLE-3 head's acceptance rate falls as context grows (the head can't predict several tokens ahead as accurately when the conditioning is 100K tokens long), and the per-step rejection cost compounds. EAGLE-3 wins big at 1–4K input (the chart's left edge crosses 80–115 tok/s, well above the 8B at the same context); past 30K it becomes a liability vs. running the same 3B base model without spec decode.
  4. Each model has a different "long context cliff." The 8B's decode TPS halves between 30K and 90K. The 30B halves between 60K and 100K. The 3B+EAGLE-3 halves between 15K and 30K. If you're building an agent that lives at one specific context length, pick the model whose cliff is on the other side of your operating point — the difference is 3–5× per-stream throughput.

All three Granite lines extend through ~70–90% of the 131K native context. The remaining headroom (the gap between our max measured input and 131,072) is a vLLM KV-cache memory-overhead issue, not an architectural one: at --gpu-memory-utilization 0.95 the engine needs slightly more KV space than is available for a full 131K-token sequence, and dropping to ~115K–120K leaves enough margin to launch.

Bench configuration: dual RTX 3090 (Ampere SM 8.6, 48 GB total), 280W power limit, vLLM 0.19.2 nightly, /v1/completions raw, temperature=0, min_tokens=max_tokens=128, ignore_eos=true, per-request nonce defeats prefix caching. Each x value is the exact prompt_tokens reported by vLLM (after tokenization). For long-context coverage we used single-replica launchers with --max-num-seqs 1 to free VRAM for KV (Granite 8B at --max-model-len 102400, 30B at 114,688, 3B+EAGLE-3 at 131,072). The lower-context end of each line uses the production launchers from earlier in this post; there's no visible discontinuity because these Granite configurations are KV-friendly enough that single-replica vs LB matters less per stream than it does on the Qwen stack.

Reproducible recipes

Recipe 1 — Granite-4.1-8B 2 replicas + nginx LB (1222 tok/s aggregate)

# Stack components — only need the AWQ-INT4 checkpoint
hf download cyankiwi/granite-4.1-8b-AWQ-INT4         # ~5.4 GB
docker pull vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08

# Launch one replica per GPU
for N in 0 1; do
  docker run -d --name granite-8b-$N --gpus "\"device=$N\"" \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    -p 860$N:8000 --ipc=host --shm-size=16gb \
    -e VLLM_NO_USAGE_STATS=1 -e HF_HUB_OFFLINE=1 \
    -e VLLM_USE_FLASHINFER_SAMPLER=1 \
    vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08 \
    --model cyankiwi/granite-4.1-8b-AWQ-INT4 --served-model-name granite-8b \
    --tensor-parallel-size 1 --max-model-len 32768 \
    --gpu-memory-utilization 0.92 \
    --max-num-seqs 32 --max-num-batched-tokens 4096 \
    --enable-prefix-caching --enable-chunked-prefill
done

# nginx LB — least_conn, no buffering for SSE streaming
cat > nginx.conf <<'EOF'
events { worker_connections 4096; }
http {
    upstream vllm_pool {
        least_conn;
        server 127.0.0.1:8600 max_fails=3 fail_timeout=10s;
        server 127.0.0.1:8601 max_fails=3 fail_timeout=10s;
    }
    proxy_read_timeout 900s; proxy_buffering off; proxy_request_buffering off;
    server {
        listen 8400;
        location / { proxy_pass http://vllm_pool; proxy_http_version 1.1; proxy_set_header Connection ""; }
    }
}
EOF
docker run -d --name granite-lb --network host \
  -v $(pwd)/nginx.conf:/etc/nginx/nginx.conf:ro nginx:alpine

Hit http://localhost:8400/v1/chat/completions with up to 64 concurrent requests for max aggregate throughput.

Recipe 2 — Granite-4.1-30B TP=2 (216 tok/s aggregate at C=8)

hf download drawais/Granite-4.1-30B-AWQ-INT4         # ~15.5 GB

docker run -d --name granite-30b --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -p 8500:8000 --ipc=host --shm-size=16gb \
  -e VLLM_NO_USAGE_STATS=1 -e HF_HUB_OFFLINE=1 \
  -e VLLM_USE_FLASHINFER_SAMPLER=1 -e VLLM_MARLIN_USE_ATOMIC_ADD=1 \
  -e VLLM_WORKER_MULTIPROC_METHOD=spawn -e NCCL_P2P_DISABLE=1 \
  vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08 \
  --model drawais/Granite-4.1-30B-AWQ-INT4 --served-model-name granite-30b \
  --tensor-parallel-size 2 --max-model-len 16384 \
  --gpu-memory-utilization 0.92 \
  --max-num-seqs 8 --max-num-batched-tokens 4096 \
  --enable-prefix-caching --enable-chunked-prefill

Do not use cyankiwi/granite-4.1-30b-AWQ-INT4 — its asymmetric INT4 group_size=32 quant produces gibberish on Ampere SM 8.6 (Marlin INT4 kernel correctness issue specific to that quant-parameter combination).

Recipe 3 — Train and serve an EAGLE-3 head for Granite-4.1-3B

Two adapted images:

Train (30 min on a single 3090):

docker run -d --name granite-eagle3-train --gpus '"device=0"' \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -v $(pwd)/cache:/workspace/cache -v $(pwd)/outputs:/workspace/outputs \
  -e HF_HUB_OFFLINE=1 -e WANDB_MODE=disabled --shm-size=16gb --ipc=host \
  granite-eagle3-trainer:v3 \
  bash -c "cd /workspace/SpecForge && torchrun --standalone --nproc_per_node 1 \
    scripts/train_eagle3.py \
    --target-model-path ibm-granite/granite-4.1-3b \
    --draft-model-config configs/granite-4.1-3b-eagle3.json \
    --train-data-path /workspace/cache/dataset/sharegpt_train.jsonl \
    --output-dir /workspace/outputs/granite-3b-eagle3 \
    --num-epochs 2 --max-num-steps 5000 \
    --batch-size 1 --max-length 2048 \
    --target-model-backend hf \
    --learning-rate 1e-4 --warmup-ratio 0.03 \
    --chat-template granite --attention-backend sdpa \
    --log-interval 25 --save-interval 1000"

Serve (the head dir needs the granite-3b tokenizer copied into it):

HEAD=$(pwd)/outputs/granite-3b-eagle3/epoch_0_step_5000
docker run -d --name granite-3b-eagle --gpus '"device=0"' \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -v "$HEAD":/eagle-head:ro \
  -p 8600:8000 --ipc=host --shm-size=16gb \
  -e VLLM_NO_USAGE_STATS=1 -e HF_HUB_OFFLINE=1 \
  -e VLLM_USE_FLASHINFER_SAMPLER=1 \
  vllm-eagle3-granite:v2 \
  --model ibm-granite/granite-4.1-3b --served-model-name granite-3b-eagle \
  --tensor-parallel-size 1 --max-model-len 16384 \
  --gpu-memory-utilization 0.85 \
  --max-num-seqs 16 --max-num-batched-tokens 4096 \
  --enable-prefix-caching --enable-chunked-prefill \
  --speculative-config '{"method":"eagle3","model":"/eagle-head","num_speculative_tokens":5}'
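The tokenizer copy mentioned above can be scripted. Run it before starting the serve container. This is a minimal sketch under a few assumptions:

- The host uses the standard Hugging Face hub cache layout (hub/models--ORG--NAME/snapshots/REV/).
- ibm-granite/granite-4.1-3b is already cached.
- The tokenizer file list is our guess; copy_tokenizer skips any name that is absent.

Both helper names are hypothetical.

```bash
#!/usr/bin/env bash
# Copy the target model's tokenizer files into the EAGLE-3 head directory
# (sketch). Assumes the standard HF hub cache layout; file list is a guess.

# snapshot_dir CACHE_ROOT REPO_ID: newest cached snapshot dir of REPO_ID.
snapshot_dir() {
  local repo_dir="$1/hub/models--${2//\//--}"   # org/name -> models--org--name
  ls -td "$repo_dir"/snapshots/*/ 2>/dev/null | head -n 1
}

# copy_tokenizer SRC_DIR DST_DIR: copy whichever tokenizer files exist.
copy_tokenizer() {
  local f copied=0
  for f in tokenizer.json tokenizer_config.json special_tokens_map.json \
           vocab.json merges.txt added_tokens.json; do
    if [[ -e "$1/$f" ]]; then
      # -L dereferences the hub cache's blob symlinks into real files.
      cp -L "$1/$f" "$2/" && copied=$((copied + 1))
    fi
  done
  [[ $copied -gt 0 ]]   # fail if no tokenizer file was found
}

# Usage: ./copy_tok.sh "$HEAD"
if [[ $# -eq 1 ]]; then
  src=$(snapshot_dir "$HOME/.cache/huggingface" ibm-granite/granite-4.1-3b)
  copy_tokenizer "$src" "$1"
fi
```

If the trainer container wrote the outputs as root, the copy into the head directory may need sudo.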

Will this work on RTX 4090 / 5090?

| GPU | SM | VRAM | Recommendation |
|---|---|---|---|
| RTX 3090 | Ampere 8.6 | 24 GB | This blog post. AWQ-INT4 only: the FP8 path is broken on Ampere. |
| RTX 4090 | Ada 8.9 | 24 GB | Same recipe; expect ~25–40% higher tok/s. Ada has native FP8 tensor cores, so FP8 is no longer software-emulated. AWQ-INT4 still reads half as many weight bytes per token, so it should remain the faster choice for memory-bound decode. |
| RTX 5090 | Blackwell 12.0 | 32 GB | Native FP8/FP4 changes the calculus completely: the official IBM granite-4.1-{8b,30b}-fp8 checkpoints become first-class. The extra 8 GB of VRAM lets you raise max-num-seqs further. The 30B at FP8 (~30 GB) just fits a single 5090, which opens up the 2-replica + LB pattern that loses on dual 3090s at that size. However, at most ~2 GB per card remains for KV cache. |

The Granite recipe is hardware-agnostic in shape — only the quantization choice changes. On Ampere: AWQ-INT4. On Blackwell: official FP8. On Hopper / Ada datacenter: either works.
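Changing the quantization changes the weight bytes per card. That in turn drives the TP=2 vs. 2-replica choice discussed under Lessons learned. Below is a minimal sketch of that rule of thumb. pick_topology is a hypothetical helper. The 0.5 threshold is a rough boundary drawn from our two data points, not a measured cutoff. The 0.92 default mirrors the --gpu-memory-utilization used in the 30B recipe.

```bash
#!/usr/bin/env bash
# Rule-of-thumb topology picker for two identical GPUs (sketch).
# The 0.5 threshold is a rough boundary, not a measured cutoff.

# pick_topology WEIGHTS_GB VRAM_GB [UTIL]
pick_topology() {
  awk -v w="$1" -v v="$2" -v u="${3:-0.92}" 'BEGIN {
    if (w >= v * u)        print "tp2 (weights do not fit one card)"
    else if (w > 0.5 * v)  print "tp2"
    else                   print "2-replica+lb"
  }'
}

pick_topology 5.4 24    # Granite-4.1-8B AWQ-INT4 on a 3090; prints: 2-replica+lb
pick_topology 15.5 24   # Granite-4.1-30B AWQ-INT4 on a 3090; prints: tp2
```

For other cards, pass two values: the checkpoint size at the quantization you plan to use, and that card's VRAM (32 for an RTX 5090). Anything close to the threshold is worth benchmarking both ways.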

Lessons learned

  1. Architecture simplicity is an optimization. Granite 4.1 LMs are dense GQA transformers with no Mamba/SSM/hybrid layers. That sounds boring, but it means TurboQuant, AWQ, and all the standard kernels work on every layer. Our Qwen3.6 work spent half its time on Genesis patches that exist only because the Qwen architecture's hybrid attention layers break standard KV-cache and MTP code paths. None of that complexity exists for Granite.
  2. Don't trust an INT4 checkpoint until you read its output. cyankiwi/granite-4.1-30b-AWQ-INT4 boots cleanly and produces fluent-looking gibberish. The Marlin INT4 kernel on Ampere has multiple code paths for different quant parameters (sym/asym, group_size 32/128) and not all of them are bulletproof for every layer-shape combination. Symmetric INT4 group_size=128 is the safer default on this hardware.
  3. TP=2 vs 2-replica LB depends on weight bytes. The Qwen-blog finding "TP=2 always loses on PCIe 3090" was an artifact of measuring at one weight size. For Granite-30B, the weights consume too much per-card VRAM to allow max-num-seqs > 4 in 2-replica mode. TP=2 frees batch headroom, and that outweighs the NCCL tax. The decision rule:
     - When the weights take well over half of one card's VRAM, TP=2 wins. Granite-30B INT4 (15.5 GB, ~0.65× of a 24 GB card) lands on this side.
     - When they take well under half, 2 replicas + LB wins. Granite-8B INT4 (5.4 GB, ~0.22×) lands on this side.
  4. FP8 weights on Ampere are a footgun, not a feature. vLLM's Marlin FP8 weight-only kernel falls back to a software dequant on Ampere and quietly produces wrong outputs for some model architectures. The granite-4.1-8b-fp8 checkpoint generates whitespace. The 30b-fp8 happens to produce coherent text, but at 45–80% lower throughput than the AWQ-INT4 path. Use AWQ-INT4 on Ampere. Period.
  5. You can train a useful EAGLE-3 head in under an hour on consumer hardware. 5000 steps at 1024-context, 5k ShareGPT samples, a single RTX 3090 → 30 minutes wall-clock → +13.5% single-stream throughput on the target. The published EAGLE-3 numbers (1.5–2× speedup at 0.4–0.6 acceptance) are achievable with 50k samples and 10 epochs. At batch size 1 that is ~500k steps, or roughly 50 hours on the same hardware at our measured step rate. The floor we found is the start of that curve, not its end.
  6. Speculative decoding has a sharp single-stream/batch trade-off. Across n=2..5 sweeps and across single-replica vs 2-replica LB topologies, EAGLE-3 wins decisively at C=1 and regresses sharply at C=8+. The crossover is around C=4. The clean operational answer is to deploy two replicas with EAGLE-3 ON on one and OFF on the other, and route latency-critical traffic to the EAGLE replica. Same hardware, both wins.
  7. max-num-seqs is a surprisingly load-bearing flag. Going from max-num-seqs=8 to max-num-seqs=32 on a single 8B replica raises aggregate throughput at C=32 by ~56% (676 → 1052 tok/s) with no other change. A conservative setting of 8 leaves a lot on the table for small models on big GPUs.
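The EAGLE-on/EAGLE-off split from lesson 6 needs only a small routing decision in front of the two replicas. In the minimal sketch below, latency-critical requests go to the EAGLE-3 replica. As our own extension of that rule, they spill to the plain replica once the EAGLE replica already has about 4 requests in flight, which is the crossover we measured.

A few details are assumptions:

- Port 8600 matches Recipe 3; port 8601 for the plain replica is hypothetical.
- The in-flight count comes from vLLM's Prometheus gauge vllm:num_requests_running on /metrics.
- The helper names are hypothetical.

```bash
#!/usr/bin/env bash
# Route requests between an EAGLE-3 replica and a plain replica (sketch).
# Assumptions: EAGLE replica on :8600 (Recipe 3), plain replica on a
# hypothetical :8601; the crossover of 4 comes from our C=1..8 sweeps.

EAGLE_PORT=8600
PLAIN_PORT=8601   # hypothetical replica started without --speculative-config
CROSSOVER=4       # EAGLE-3 stops paying off around this many in-flight requests

# choose_port CLASS RUNNING: CLASS is "latency" or "bulk"; RUNNING is the
# number of requests currently in flight on the EAGLE replica.
choose_port() {
  if [[ "$1" == latency ]] && awk -v r="$2" -v c="$CROSSOVER" 'BEGIN { exit !(r < c) }'; then
    echo "$EAGLE_PORT"
  else
    echo "$PLAIN_PORT"
  fi
}

# eagle_running: sum vLLM's running-requests gauge (0 if the metric is absent).
eagle_running() {
  curl -s "http://localhost:$EAGLE_PORT/metrics" |
    awk '/^vllm:num_requests_running/ { sum += $NF } END { print sum + 0 }'
}

# Usage: ./route.sh latency   (prints the port to send this request to)
if [[ $# -eq 1 ]]; then choose_port "$1" "$(eagle_running)"; fi
```

In production the same decision would live in the load balancer. The shell version only shows the rule.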

Appendix — full software/config diff

Driver
580.159.03 (CUDA 13.0 capable)
vLLM image
vllm/vllm-openai:nightly-07351e0883470724dd5a7e9730ed10e01fc99d08 (vllm 0.19.2rc1.dev205)
8B model
cyankiwi/granite-4.1-8b-AWQ-INT4 (5.4 GB)
30B model
drawais/Granite-4.1-30B-AWQ-INT4 (15.5 GB, sym INT4 g128)
3B target (EAGLE-3)
ibm-granite/granite-4.1-3b (BF16, ~6 GB)
EAGLE-3 trainer
SpecForge HEAD with granite chat template + lazy sglang imports
EAGLE-3 server
vLLM nightly with granite EAGLE3 whitelist + EagleModelMixin patch
KV cache
auto (BF16) — fp8_e5m2 rejected on compressed-tensors
Tensor parallel
TP=1 + 2 replicas + LB (8B), TP=2 (30B), TP=1 single-replica (3B + EAGLE-3)
Speculative decoding
none (8B production), EAGLE-3 n=5 trained ourselves (3B / latency endpoint)

Headline numbers, side-by-side

| Stack | Hardware | tok/s C=1 | tok/s C=4 | tok/s peak | TPOT C=1 |
|---|---|---|---|---|---|
| Qwen3.6-27B + 2× LB + MTP=5 | 2× 3090 | 100 | 225 | 225 | 9.7 ms |
| Qwen3.6-35B-A3B MoE + TP=2+EP | 2× 3090 | 90 | 283 | 283 | 9.95 ms |
| Granite-4.1-8B AWQ-INT4 + 2× LB | 2× 3090 | 105 | 307 | 2624 | 9.2 ms |
| Granite-4.1-30B AWQ-INT4 + TP=2 | 2× 3090 | 44 | 138 | 216 | 21.6 ms |
| Granite-4.1-3B + EAGLE-3 (trained) | 1× 3090 | 105 | 309 | 311 | 9.25 ms |

For an agentic system fanning out 2–8 tool calls, the granite-8B 2-replica stack is the answer. For absolute single-stream latency on a smaller model, the granite-3B + EAGLE-3 is the answer. For maximum-quality occasional generations, the granite-30B TP=2 is the answer. Three different shapes, three different topologies, all on the same dual-3090 box.

Credits & sources