Added benchmark for LLaMA 3 model for attention tests #3930
howardzhang-cv merged 34 commits into main from
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3930
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit b6f072f with merge base 42bcdc4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
def load_wikitext2_tokens(tokenizer, seq_len: int):
We don't have to define a wikitext tokenizer here. Instead, we can run lm-eval directly.
Thanks! I changed it to use this instead.
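For reference, a typical lm-eval invocation for WikiText perplexity might look like the following. This is a sketch, not taken from this PR; the model id matches the benchmark's default, but the exact flags and task name are assumptions about lm-eval's CLI:

```shell
# Hypothetical lm-eval invocation; evaluates word-level perplexity on WikiText.
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-3.1-8B,dtype=bfloat16 \
  --tasks wikitext \
  --batch_size 8
```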
def compute_perplexity(model, chunks, device: str, backend_name: str) -> float:
This is also not needed. See the above comment.
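For context, the perplexity computation the helper performs reduces to exponentiating the average per-token cross-entropy loss. A minimal stdlib-only sketch (hypothetical helper, not the PR's actual code):

```python
import math

def perplexity_from_losses(token_losses):
    # Perplexity = exp(mean per-token negative log-likelihood).
    # token_losses holds the cross-entropy loss for each evaluated chunk/token.
    avg_loss = sum(token_losses) / len(token_losses)
    return math.exp(avg_loss)
```

A loss of 0 everywhere yields a perplexity of exactly 1 (the model is certain), and higher average loss yields higher perplexity.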
def benchmark_runtime(
Q. Can we compute forward pass latency using vLLM directly, similar to e2e?
Unfortunately not. Unlike the other quantization APIs in TorchAO, the low precision attention path requires replacing F.scaled_dot_product_attention with a specific attention backend capable of low precision attention (e.g. FA3/4). So we need to ensure that our model calls F.SDPA.
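For context, a forward-pass latency measurement like the one discussed typically follows a warmup-then-measure pattern. Below is a stdlib-only sketch of that pattern (a hypothetical helper, not the PR's code; a real GPU benchmark would additionally need to synchronize the device, e.g. with torch.cuda.synchronize(), around each timed call):

```python
import statistics
import time

def median_latency_ms(fn, warmup=3, iters=10):
    # Warm up to exclude one-time costs (compilation, caches, lazy init).
    for _ in range(warmup):
        fn()
    # Time several iterations and report the median, which is robust to outliers.
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
```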
RANDOM_SEED = 42
DEFAULT_MODEL_ID = "meta-llama/Llama-3.1-8B"
What does DEFAULT_MODEL_ID do? Should it be used directly as the default for an args flag (--model_id)?
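One way to resolve this, as the question suggests, is to wire the constant in as the argparse default. A minimal sketch (hypothetical, assuming the script parses its CLI with argparse):

```python
import argparse

DEFAULT_MODEL_ID = "meta-llama/Llama-3.1-8B"

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="FP8 attention benchmark")
    # The constant becomes the default, so --model_id is optional on the CLI.
    parser.add_argument(
        "--model_id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help="Hugging Face model id to benchmark",
    )
    return parser.parse_args(argv)
```

Running the script without --model_id then falls back to the LLaMA 3.1 8B default, while still allowing an override.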
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 528c2ec Pull-Request: pytorch#3930
Benchmark script for evaluating FP8 attention on LLaMA 3 models. Measures perplexity on WikiText-2 and runtime performance across sequence lengths with and without RoPE fusion. ghstack-source-id: 859f523 Pull-Request: pytorch#3930
namgyu-youn left a comment
LGTM, thanks for addressing all the comments!
Stack from ghstack (oldest at bottom):
Summary
Example Run
python benchmarks/prototype/attention/eval_llama3_model.py --baseline fa3 --test fa3_fp8