Add static-shape wrapper for TensorRT export (4.8x speedup) #520
2imi9 wants to merge 2 commits into allenai:main
Conversation
StaticOlmoEarthEncoder wraps the FlexiViT encoder with fixed shapes, enabling torch.export() and TensorRT compilation. It references the same trained weights — no copying or retraining.

Results on OlmoEarth-v1-Base (RTX 5090, bs=4):

- PyTorch eager FP32: 166.9 ms (1.0x)
- TensorRT FP16: 34.6 ms (4.8x), cosine similarity 0.999999

Files:

- olmoearth_pretrain/export.py: StaticOlmoEarthEncoder + export pipeline
- scripts/benchmark_trt_export.py: TRT benchmark script
- tests/unit/test_export.py: 11 unit tests (all passing)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6c90296854
```python
from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
from olmoearth_pretrain.export import StaticOlmoEarthEncoder, verify_export
from olmoearth_pretrain.model_loader import ModelID, load_model_from_id
from olmoearth_pretrain.quantization import check_modelopt_available, quantize_model
```
Remove import of nonexistent quantization module

benchmark_trt_export.py imports olmoearth_pretrain.quantization at module load time, but this repo does not contain that module (a repo-wide search for quantize_model/check_modelopt_available finds only these new call sites). As a result, running the script raises ImportError before argument parsing, so even `--precision fp32 --skip-trt` cannot execute.
Fixed — now uses try/except with graceful fallback when quantization module isn't installed.
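A graceful-fallback import of this kind typically looks like the sketch below. This is an illustration of the pattern, not the PR's exact code; the `HAVE_QUANTIZATION` flag name is an assumption, and the module is expected to be absent until PR #516 lands.

```python
# Optional dependency: the quantization module lives in a separate PR,
# so fall back cleanly when it is not importable.
try:
    from olmoearth_pretrain.quantization import (
        check_modelopt_available,
        quantize_model,
    )
    HAVE_QUANTIZATION = True
except ImportError:
    check_modelopt_available = None  # type: ignore[assignment]
    quantize_model = None  # type: ignore[assignment]
    HAVE_QUANTIZATION = False

# Later, quantization paths are gated on the flag, so
# --precision fp32 --skip-trt works without the module:
if HAVE_QUANTIZATION:
    print("quantization available")
else:
    print("quantization disabled")
```

Because the try/except runs at module load time, argument parsing and all non-quantization code paths proceed regardless of whether the module is installed.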
```python
torch.manual_seed(i + 42)
# Create input tensor
x = torch.randn(1, spatial, spatial, T, S2_MODALITY.num_bands, device=device)
ts = torch.tensor([[[1, 6, 2020]]], dtype=torch.long, device=device)
```
Construct timestamps with the configured timestep count

verify_export hardcodes ts to shape [1, 1, 3] while x and the masks use T = static_encoder.num_timesteps; when num_timesteps > 1 (supported by StaticOlmoEarthEncoder and exposed by export_to_tensorrt), static_encoder(x, ts) fails because the temporal encodings expect T timestamps. This breaks verification and export for multi-temporal inputs.
Fixed — now uses num_timesteps from the static encoder config.
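The fixed construction can be sketched as below: the timestamp tensor is built from the wrapper's configured timestep count rather than a hardcoded single entry. The `[day, month, year]` layout is taken from the diff above; using a plain int in place of `static_encoder.num_timesteps` is an assumption for the example.

```python
import torch

# Stand-in for static_encoder.num_timesteps (assumption for the example):
num_timesteps = 3

# One [day, month, year] triple per timestep, giving shape [1, T, 3]
# instead of the hardcoded [1, 1, 3]:
ts = torch.tensor([[[1, 6, 2020]] * num_timesteps], dtype=torch.long)
print(ts.shape)  # torch.Size([1, 3, 3])
```

With T timestamp triples, the temporal encodings line up with the T axis of `x`, so multi-temporal verification no longer fails.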
…ount

- benchmark_trt_export.py: graceful fallback when the quantization module is not installed (it lives in PR allenai#516)
- export.py: verify_export uses num_timesteps instead of hardcoded T=1

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Static-shape wrapper for OlmoEarth's FlexiViT encoder, enabling torch.export() and TensorRT compilation.
FlexiViT's dynamic shapes (boolean indexing, data-dependent slicing, ndim branching) block torch.export/TensorRT. StaticOlmoEarthEncoder replays the encoder's forward pass with all shapes baked in at construction time, referencing the same trained weights — no retraining.
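The "same trained weights" claim is what the PR's cosine-similarity figure (0.999999) is checking: the exported graph should reproduce the eager outputs up to numerical noise. A hypothetical sketch of that kind of check, using random tensors in place of real encoder outputs:

```python
import torch
import torch.nn.functional as F

# Stand-ins for eager vs. exported encoder outputs (assumption:
# FP16/compilation introduces only tiny numerical perturbations).
eager_out = torch.randn(4, 768)
exported_out = eager_out + 1e-4 * torch.randn(4, 768)

# Flatten both and compare direction with cosine similarity;
# values near 1.0 indicate the outputs are numerically equivalent.
cos = F.cosine_similarity(eager_out.flatten(), exported_out.flatten(), dim=0)
print(cos.item() > 0.999)
```

A threshold close to 1.0 (the PR reports 0.999999 for TensorRT FP16) distinguishes acceptable precision loss from a genuinely divergent export.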
Results on OlmoEarth-v1-Base (RTX 5090, bs=4): PyTorch eager FP32 166.9 ms (1.0x); TensorRT FP16 34.6 ms (4.8x), cosine similarity 0.999999.

Files:

- olmoearth_pretrain/export.py: StaticOlmoEarthEncoder + export pipeline
- scripts/benchmark_trt_export.py: TRT benchmark script
- tests/unit/test_export.py: 11 unit tests (all passing)

Related: #516 (quantization accuracy evaluation)