
W4A16-INT (TinyGEMM+HQQ) is 6-13x slower than FP16 for batched inference #3496


Description


Overview

One of the W4A16-INT APIs shows poor throughput in the vLLM benchmark.

  • Device: NVIDIA A100 80GB
  • Base Model: Qwen/Qwen3-8B
  • Config: Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
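
For reference, a minimal sketch of how this config is applied with torchao's quantize_ API (this is my own hedged summary; the release script below also handles the tokenizer, model card, and upload):

import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Load the base model in bf16, then quantize the Linear weights in place
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B", torch_dtype=torch.bfloat16, device_map="cuda:0"
)
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="tile_packed_to_4d",     # TinyGEMM weight layout
    int4_choose_qparams_algorithm="hqq",         # HQQ for choosing scales/zeros
)
quantize_(model, config)  # swaps Linear weights for int4 tensor subclasses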

This issue was observed using the TorchAO model release script, https://github.com/pytorch/ao/blob/main/.github/scripts/torchao_model_releases/quantize_and_upload.py:

# Release command
# Checkpoint: https://huggingface.co/namgyu-youn/Qwen3-8B-INT4
python .github/scripts/torchao_model_releases/quantize_and_upload.py \
  --model_id Qwen/Qwen3-8B \
  --quant INT4 \
  --push_to_hub \
  --push_to_user_id namgyu-youn \
  --populate_model_card_template

In the vLLM benchmark, W4A16-INT achieved 1275.87 total tokens/s and 255.17 output tokens/s, while the original (FP16) model achieved 10091.15 total tokens/s and 2018.23 output tokens/s.


Experiment: Low throughput in vLLM

After generating the model, I measured performance with vllm bench throughput and got poor results:

vllm bench throughput \
  --model namgyu-youn/Qwen3-8B-INT4 \
  --input-len 512 --output-len 128 \
  --num-prompts 100

When dataset path is not set, it will default to random dataset
tokenizer_config.json: 5.40kB [00:00, 12.2MB/s]
vocab.json: 2.78MB [00:00, 10.5MB/s]
merges.txt: 1.67MB [00:00, 6.43MB/s]
(...)
Throughput: 1.99 requests/s, 1275.87 total tokens/s, 255.17 output tokens/s

For comparison, the original model's throughput is:

vllm bench throughput \
  --model Qwen/Qwen3-8B \
  --input-len 512 --output-len 128 \
  --num-prompts 100
(...)
Throughput: 15.77 requests/s, 10091.15 total tokens/s, 2018.23 output tokens/s

Debugging 01: Check whether the expected kernel (TinyGEMM) is called

First, I wanted to check whether an incorrect kernel was being called, so I profiled a single quantized linear layer:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "namgyu-youn/Qwen3-8B-INT4",
    device_map="cuda:0",
    torch_dtype=torch.bfloat16
)

# Prefill-shaped input: batch=1, seq_len=512, hidden_size=4096 (Qwen3-8B)
x = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda")

# Profile CUDA activity for a single quantized q_proj call
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        out = model.model.layers[0].self_attn.q_proj(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

And the result is:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void at::native::tinygemm_m16n8k16_chunk_kernel<at::...         0.00%       0.000us         0.00%       0.000us       0.000us     925.791us        99.36%     925.791us     925.791us             1  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       5.984us         0.64%       5.984us       5.984us             1  
                                        cudaMemcpyAsync         3.52%      62.763us        24.96%     445.328us     445.328us       0.000us         0.00%       0.000us       0.000us             1  
                                Activity Buffer Request        21.44%     382.565us        21.44%     382.565us     382.565us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel         2.17%      38.782us        37.85%     675.376us     675.376us       0.000us         0.00%       0.000us       0.000us             1  
                       Runtime Triggered Module Loading        34.01%     606.924us        34.01%     606.924us     303.462us       0.000us         0.00%       0.000us       0.000us             2  
                                  Lazy Function Loading         1.66%      29.670us         1.66%      29.670us      29.670us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaFuncGetAttributes         1.13%      20.125us         1.13%      20.125us      20.125us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize        36.06%     643.476us        36.06%     643.476us     643.476us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.784ms
Self CUDA time total: 931.775us

This shows that TinyGEMM is indeed being called, so kernel dispatch is not the issue.


Debugging 02: Check with small batch size

Wondering whether M (the number of tokens per matmul, which drives arithmetic intensity) matters for throughput, I tested with a low M:

vllm bench throughput \
  --model namgyu-youn/Qwen3-8B-INT4 \
  --input-len 1 \
  --output-len 512 \
  --num-prompts 100

And the result is:

Throughput: 2.21 requests/s, 1134.09 total tokens/s, 1131.88 output tokens/s
Total num prompt tokens:  100
Total num output tokens:  51200

This shows that TinyGEMM delivers poor throughput under vLLM (continuous batching) regardless of M.


Debugging 03: TinyGEMM kernel performance vs batch size (M)

To find the cause, I measured raw kernel performance at varying sequence lengths and batch sizes:

# _weight_int4pack_mm vs FP16 matmul at different M
| M   | INT4 (ms) | FP16 (ms) | Slowdown (INT4/FP16) |
|-----|-----------|-----------|----------------------|
| 1   | 0.026     | 0.070     | 0.37x                |
| 100 | 0.554     | 0.082     | 6.7x                 |
| 512 | 2.773     | 0.212     | 13x                  |

tinygemm_m16n8k16_chunk_kernel is only efficient for M=1. For M>1, it's 6-13x slower than FP16.
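
A hedged sketch of how this comparison can be reproduced. Instead of calling torch._weight_int4pack_mm directly, it times the quantized q_proj module (which dispatches to the same TinyGEMM kernel, as the profile above shows) against a plain bf16 Linear of the same shape; exact numbers will differ from the table above.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "namgyu-youn/Qwen3-8B-INT4", device_map="cuda:0", torch_dtype=torch.bfloat16
)
int4_proj = model.model.layers[0].self_attn.q_proj  # weight is an int4 tensor subclass
fp16_proj = torch.nn.Linear(
    int4_proj.in_features, int4_proj.out_features,
    bias=False, dtype=torch.bfloat16, device="cuda",
)

def bench_ms(fn, x, iters=50):
    # Warm up, then time with CUDA events so launch/compile overhead is excluded
    for _ in range(5):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

with torch.no_grad():
    for m in (1, 100, 512):
        x = torch.randn(m, int4_proj.in_features, dtype=torch.bfloat16, device="cuda")
        print(f"M={m}: INT4 {bench_ms(int4_proj, x):.3f} ms | FP16 {bench_ms(fp16_proj, x):.3f} ms")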

Hypothesis from Claude (not yet verified): vLLM uses continuous batching, so M > 1 during both prefill and decode, which makes TinyGEMM inefficient.


Conclusion

Even though perplexity is quite good, the poor throughput shows these configs are not suitable for real serving. Should I:

  1. Support HQQ for other layouts like Int4MarlinSparseTensor,
  2. Document that TinyGEMM is optimized for single-batch (M=1) inference only,
  3. or do something else?
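
If option 1 is the direction, the user-facing usage I have in mind would look roughly like the sketch below. The packing-format string "marlin_sparse" is my assumption about the corresponding value; HQQ support for that layout is exactly what is missing today.

from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical config: HQQ qparams with a batched-inference-friendly layout.
# "marlin_sparse" is an assumed packing-format string; combining it with HQQ
# does not work today, which is what option (1) would change.
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="marlin_sparse",
    int4_choose_qparams_algorithm="hqq",
)
# quantize_(model, config)  # would require option (1) to be implemented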
