
Add static-shape wrapper for TensorRT export (4.8x speedup)#520

Open
2imi9 wants to merge 2 commits into allenai:main from 2imi9:trt-export

Conversation


@2imi9 2imi9 commented Mar 31, 2026

Static-shape wrapper for OlmoEarth's FlexiViT encoder, enabling torch.export() and TensorRT compilation.

FlexiViT's dynamic shapes (boolean indexing, data-dependent slicing, ndim branching) block torch.export/TensorRT. StaticOlmoEarthEncoder replays the encoder's forward pass with all shapes baked in at construction time, referencing the same trained weights — no retraining.

Results on OlmoEarth-v1-Base (RTX 5090, bs=4):

| Method | Latency | Speedup | Cosine Sim |
| --- | --- | --- | --- |
| PyTorch eager FP32 | 166.9 ms | 1.0x | |
| TensorRT FP16 | 34.6 ms | 4.8x | 0.999999 |

Files:

  • olmoearth_pretrain/export.py: StaticOlmoEarthEncoder + export pipeline
  • scripts/benchmark_trt_export.py: TRT benchmark script
  • tests/unit/test_export.py: 11 unit tests (all passing)

Related: #516 (quantization accuracy evaluation)


@chatgpt-codex-connector (Bot) left a comment:


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6c90296854


Comment thread: scripts/benchmark_trt_export.py (outdated)
```python
from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
from olmoearth_pretrain.export import StaticOlmoEarthEncoder, verify_export
from olmoearth_pretrain.model_loader import ModelID, load_model_from_id
from olmoearth_pretrain.quantization import check_modelopt_available, quantize_model
```

P1: Remove import of nonexistent quantization module

benchmark_trt_export.py imports olmoearth_pretrain.quantization at module load time, but this repo does not contain that module (repo-wide search for quantize_model/check_modelopt_available only finds these new call sites). As a result, running the script raises ImportError before argument parsing, so even --precision fp32 --skip-trt cannot execute.


Author (2imi9):

Fixed — now uses try/except with graceful fallback when quantization module isn't installed.
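
The fallback the author describes likely looks something like the following optional-import pattern. This is a sketch: only the imported names come from the snippet above; the fallback body is an assumption.

```python
# Optional import of the quantization module from PR #516; the script
# degrades gracefully when it is absent instead of crashing at load time.
try:
    from olmoearth_pretrain.quantization import (  # noqa: F401
        check_modelopt_available,
        quantize_model,
    )

    HAVE_QUANTIZATION = True
except ImportError:
    HAVE_QUANTIZATION = False

    def quantize_model(model, *args, **kwargs):
        # Hypothetical fallback: fail only when quantization is requested,
        # so --precision fp32 / --skip-trt paths still run.
        raise RuntimeError(
            "olmoearth_pretrain.quantization is unavailable (added in PR #516); "
            "run with --precision fp32 or --skip-trt."
        )
```

This keeps the ImportError out of module load time, so argument parsing and non-quantized code paths work even without the optional dependency.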

Comment thread: olmoearth_pretrain/export.py (outdated)
```python
torch.manual_seed(i + 42)
# Create input tensor
x = torch.randn(1, spatial, spatial, T, S2_MODALITY.num_bands, device=device)
ts = torch.tensor([[[1, 6, 2020]]], dtype=torch.long, device=device)
```

P1: Construct timestamps with the configured timestep count

verify_export hardcodes ts to shape [1, 1, 3] while x and masks use T = static_encoder.num_timesteps; when num_timesteps > 1 (supported by StaticOlmoEarthEncoder and exposed by export_to_tensorrt), static_encoder(x, ts) fails because temporal encodings expect T timestamps. This breaks verification/export for multi-temporal inputs.


Author (2imi9):

Fixed — now uses num_timesteps from the static encoder config.
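
A sketch of what the fixed timestamp construction could look like. The helper name is hypothetical, and the (day, month, year) interpretation is inferred from the `[[[1, 6, 2020]]]` literal in the snippet above:

```python
import torch


def make_timestamps(num_timesteps: int, device: str = "cpu") -> torch.Tensor:
    """Build a [1, T, 3] (day, month, year) tensor so the temporal
    encodings see one timestamp per timestep instead of a hardcoded T=1."""
    # Placeholder values: day 1 of successive months in 2020.
    triples = [[1, (m % 12) + 1, 2020] for m in range(num_timesteps)]
    return torch.tensor([triples], dtype=torch.long, device=device)


# T here should come from static_encoder.num_timesteps so ts matches
# x's shape [1, spatial, spatial, T, bands] along the T axis.
ts = make_timestamps(4)
assert ts.shape == (1, 4, 3)
```

With `num_timesteps=1` this reproduces the original single-timestep shape, so the fix is backward compatible with the existing test inputs.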

…ount

- benchmark_trt_export.py: graceful fallback when quantization module
  is not installed (it lives in PR allenai#516)
- export.py: verify_export uses num_timesteps instead of hardcoded T=1

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>