Add FP4/FP8 weight quantization for Blackwell/Hopper GPU inference#516
2imi9 wants to merge 6 commits into allenai:main
Conversation
nvidia-modelopt-based weight quantization module and benchmark scripts for OlmoEarth ViT models. Supports FP4 (Blackwell) and FP8 (Hopper+).
Results on OlmoEarth-v1-Nano (1.36M params, RTX 5090):
EuroSAT KNN classification (real Sentinel-2 imagery, 27K samples):
FP32: 65.6% accuracy (baseline)
FP8: 65.4% accuracy (-0.2%)
FP4: 63.6% accuracy (-2.0%)
Embedding cosine similarity vs FP32:
FP8: 0.999 mean | FP4: 0.980 mean
Quantized models: huggingface.co/2imi9/olmoearth-nano-fp8, olmoearth-nano-fp4
W&B: wandb.ai/2imi9-northeastern-university/OlmoEarth_Q
Files:
- olmoearth_pretrain/quantization.py: reusable quantization module
- scripts/nvfp4_quantization.py: quantization + cosine similarity pipeline
- scripts/eval_quantization.py: real-data KNN evaluation on EuroSAT
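As a reference for the embedding cosine similarity metric reported above, a minimal sketch of how it could be computed (names are illustrative, not the PR's actual code):

```python
import torch
import torch.nn.functional as F

def mean_cosine_similarity(fp32_emb: torch.Tensor, quant_emb: torch.Tensor) -> float:
    """Mean per-sample cosine similarity between FP32 and quantized
    embeddings of shape (num_samples, embed_dim) from identical inputs."""
    return F.cosine_similarity(fp32_emb, quant_emb, dim=1).mean().item()
```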
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e16edacdde
```python
fp32_model.eval()
# Deep copy for FP4 so we keep the original FP32 model
fp4_model = copy.deepcopy(fp32_model)
fp4_model = step2_quantize(fp4_model, args.quant_config, precision=args.precision)
```
Gate FP4/FP8 comparisons on quantization success
step2_quantize() returns the input model unchanged when quantization is skipped or fails, and the caller still stores that object in fp4_model after a deep copy. Later steps then benchmark and compare an unquantized FP32 clone as if it were FP4/FP8, which can silently invalidate the reported quality and throughput results in environments without ModelOpt/CUDA/nvcc.
Fixed in b0d56a7 — now checks count_quantizer_nodes() after quantization and falls back to FP32 if no nodes were inserted.
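A minimal sketch of that guard, assuming the count_quantizer_nodes() helper from this PR's quantization module (the surrounding call sites are illustrative):

```python
fp4_model = step2_quantize(fp4_model, args.quant_config, precision=args.precision)

# step2_quantize() returns the model unchanged when ModelOpt/CUDA is missing,
# so verify that quantizer nodes were actually inserted before benchmarking.
if count_quantizer_nodes(fp4_model) == 0:
    print("No quantizer nodes inserted; falling back to the FP32 baseline.")
    fp4_model = fp32_model  # avoid reporting an FP32 clone as FP4/FP8
```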
```python
sim = test_emb @ train_emb.t()

# Top-k
topk_sim, topk_idx = sim.topk(k, dim=1)
```
Bound KNN top-k by training set size
The KNN path always executes sim.topk(k, dim=1) with k=20, so runs with fewer than 20 training embeddings (for example --max-train 8) crash with an out-of-range error instead of producing metrics. Using min(k, n_train) avoids this hard failure for smaller/debug subsets.
Fixed in b0d56a7 — added k = min(k, len(train_emb)).
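In context, the bounded top-k looks roughly like this (variable names follow the diff above):

```python
# Cosine similarities between test and train embeddings.
sim = test_emb @ train_emb.t()

# Bound k by the number of training embeddings so small/debug subsets
# (e.g. --max-train 8) don't trigger an out-of-range topk.
k = min(k, len(train_emb))
topk_sim, topk_idx = sim.topk(k, dim=1)
```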
- Check quantizer node count after step2_quantize to detect failed quantization instead of silently benchmarking an FP32 clone as FP4
- Bound KNN k by training set size to prevent crash with small subsets
Complementary to #477 — that PR quantizes output embeddings for storage, this PR quantizes model weights for inference.
Note: Both FP8 and FP4 use simulated quantization (real precision loss, FP32 compute). Native inference speedup requires TensorRT export, currently blocked by FlexiViT's dynamic shapes. The quantization module and accuracy results are ready for when export support matures.
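To make "simulated quantization" concrete: a minimal fake-quantization sketch using a symmetric integer grid as a stand-in for the FP8/FP4 formats (this shows the general quantize-dequantize idea, not ModelOpt's actual implementation):

```python
import torch

def fake_quantize(w: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Round weights onto a low-precision grid, then return them in FP32.

    The result carries the precision loss of the low-bit format, but all
    subsequent matmuls still run in FP32: no speedup, plus extra
    quantize-dequantize overhead at inference time.
    """
    qmax = 2 ** (num_bits - 1) - 1            # e.g. 127 for 8-bit
    scale = w.abs().max() / qmax              # per-tensor symmetric scale
    q = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return q * scale                          # dequantize back to FP32
```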
I tried TensorRT export with both dynamo and TorchScript; both fail due to FlexiViT's dynamic shapes. Would a static-shape export path be worth exploring, or is the accuracy validation sufficient for now?
Tests for count_quantizable_layers, get_model_memory_mb, count_quantizer_nodes, _get_quant_config, and availability checks. 17 tests covering all public functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
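A unit test in this style might look like the following (the exact semantics of get_model_memory_mb are assumed: FP32 parameter bytes reported in MB):

```python
import torch
from olmoearth_pretrain.quantization import get_model_memory_mb

def test_get_model_memory_mb_matches_param_count():
    # 256*256 FP32 weights -> 256 * 256 * 4 bytes = 0.25 MB.
    model = torch.nn.Linear(256, 256, bias=False)
    expected_mb = 256 * 256 * 4 / (1024 ** 2)
    assert abs(get_model_memory_mb(model) - expected_mb) < 1e-3
```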
Hi, thanks for your interest in OlmoEarth! Based on the accuracy numbers you report, it seems like you are loading random weights; ViT-B is expected to get ~94-95% on KNN. Also, it would be great if you could share the latency under different settings for the quantization, and screenshots of W&B, as we don't have access to the link you shared.
…retrained normalization
Thanks for catching that — you were right. Two issues: I tested on Nano instead of Base, and my EuroSAT loader used RGB zero-padded to 12 bands instead of real Sentinel-2 data. Fixed in f617aa1 — the loader now reads EuroSAT multispectral .tif files (13 Sentinel-2 bands), maps them to OlmoEarth's 12-band order, and normalizes with the pretraining statistics. Updated results on OlmoEarth-v1-Base (RTX 5090) are in the screenshots below.
Latency is higher for the quantized models because they are fake-quantized (simulated precision loss, FP32 compute). Native speedup requires TensorRT export.
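A rough sketch of that loader fix (the band index mapping below is a hypothetical placeholder, not the actual mapping in scripts/eval_quantization.py):

```python
import numpy as np
import rasterio

# HYPOTHETICAL mapping from EuroSAT's 13 Sentinel-2 bands to OlmoEarth's
# 12-band order; the real indices live in the PR's eval script.
OLMOEARTH_BAND_IDX = [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 0, 9]

def load_eurosat_tile(path: str, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Read a multispectral EuroSAT .tif, reorder to OlmoEarth's 12 bands,
    and normalize with the per-band pretraining statistics."""
    with rasterio.open(path) as src:
        bands = src.read().astype(np.float32)      # (13, H, W)
    bands = bands[OLMOEARTH_BAND_IDX]              # (12, H, W)
    return (bands - mean[:, None, None]) / std[:, None, None]
```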
StaticOlmoEarthEncoder wraps the FlexiViT encoder with fixed shapes, enabling torch.export() and TensorRT compilation. References the same trained weights — no copying or retraining.

Results on OlmoEarth-v1-Base (RTX 5090, bs=4):
PyTorch eager FP32: 166.9ms (1.0x)
TensorRT FP16: 34.6ms (4.8x), cosine sim 0.999999

Files:
- olmoearth_pretrain/export.py: StaticOlmoEarthEncoder + export pipeline
- scripts/benchmark_trt_export.py: TRT benchmark script
- tests/unit/test_export.py: 11 unit tests (all passing)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
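For context, the general shape of such a static wrapper (a sketch with assumed names, not the PR's actual export.py):

```python
import torch

class StaticEncoderWrapper(torch.nn.Module):
    """Pin every input dimension of a dynamic-shape encoder so the graph
    becomes traceable by torch.export and compilable by TensorRT."""

    def __init__(self, encoder: torch.nn.Module, shape: tuple[int, ...]):
        super().__init__()
        self.encoder = encoder  # references the trained weights, no copy
        self.shape = shape      # fixed input shape, e.g. (B, C, T, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert tuple(x.shape) == self.shape, f"expected {self.shape}, got {tuple(x.shape)}"
        return self.encoder(x)

# Usage sketch: export with fully static shapes (no dynamic_shapes argument).
# wrapper = StaticEncoderWrapper(encoder, (4, 12, 1, 224, 224))
# exported = torch.export.export(wrapper, (torch.randn(*wrapper.shape),))
```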
This reverts commit 84efff3.
I also tried to solve the TensorRT export issue, in #520: a static-shape wrapper, giving a 4.8x speedup with TRT FP16.
…ount

- benchmark_trt_export.py: graceful fallback when the quantization module is not installed (it lives in PR allenai#516)
- export.py: verify_export uses num_timesteps instead of hardcoded T=1

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


nvidia-modelopt weight quantization module and benchmark scripts for OlmoEarth ViT. Supports FP4 (Blackwell) and FP8 (Hopper+).
Complementary to #477 — that PR quantizes output embeddings for storage, this PR quantizes model weights for inference.
Results on OlmoEarth-v1-Base (86M params, RTX 5090):
EuroSAT KNN classification (Sentinel-2 multispectral, 27K samples, pretrained normalization):
Note: FP8/FP4 use simulated quantization (real precision loss, FP32 compute). Latency is higher due to quantize-dequantize overhead. Native inference speedup requires TensorRT export, blocked by FlexiViT's dynamic shapes.
Files:
- olmoearth_pretrain/quantization.py: reusable quantization module
- scripts/nvfp4_quantization.py: quantization + cosine similarity pipeline
- scripts/eval_quantization.py: EuroSAT multispectral KNN evaluation