Quantization using distribution of embeddings on pre-training dataset #477
Conversation
It generates the HDF5 file containing the quantiles in the format expected by the load_quantiles_config function.
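A minimal sketch of how such a file could be written with h5py, assuming load_quantiles_config reads named datasets; the dataset names "quantiles" and "midpoints" below are guesses, not necessarily the format the function actually expects:

```python
import h5py
import numpy as np

def save_quantiles_config(path: str, quantiles: np.ndarray, midpoints: np.ndarray) -> None:
    # quantiles: (dim, num_buckets + 1) bucket edges per embedding dimension
    # midpoints: (dim, num_buckets) representative value per bucket
    with h5py.File(path, "w") as f:
        f.create_dataset("quantiles", data=quantiles)
        f.create_dataset("midpoints", data=midpoints)
```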
```python
    return config


def quantize_embeddings_percentile(
```
Did this percentile-based bucketing approach behave substantially differently from the statistics-naive quantization scheme in https://github.com/allenai/olmoearth_run/blob/006496243c8f00ada3b74a77874e87a93bfa661e/src/olmoearth_run/runner/tools/postprocessors/combine_geotiff.py#L48?
I imagine it's probably pretty important with the very low-bit quantizations, but do you have a sense at int8?
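For reference, a rough sketch of the two schemes being compared, as I understand them; this is illustrative only, not the actual code in either repository:

```python
import torch

def quantize_fixed(x: torch.Tensor, lo: float = -1.0, hi: float = 1.0) -> torch.Tensor:
    # Statistics-naive scheme: clip to a fixed range and map it linearly onto 0..255.
    x = x.clamp(lo, hi)
    return ((x - lo) / (hi - lo) * 255).round().to(torch.uint8)

def quantize_percentile(x: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    # Percentile-based scheme: edges are per-dimension quantiles of the
    # pre-training embedding distribution, shape (dim, 257) for 256 buckets.
    out = torch.empty(x.shape, dtype=torch.uint8)
    for d in range(x.shape[-1]):
        # Use interior edges only, so bucket indices land in 0..255.
        out[..., d] = torch.bucketize(x[..., d], edges[d, 1:-1]).to(torch.uint8)
    return out
```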
AFAIK neither approach showed any drop in performance at int8; here are Mike's results:
https://github.com/allenai/olmoearth_pretrain/blob/main/scripts/archived/2026-01-024_embedding_analysis/quant_comparison_rounded.csv
For the platform I think you should just go with the simple fixed quantization.
Also, I computed the per-band distribution here, but I found that all of the bands follow almost the same distribution.
A normal distribution centered at 0.0? That would be convenient for int1 😅
Guess it doesn't need to be a normal distribution when there are only two quantiles.
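A quick sketch of that point: with only two buckets, the single threshold that matters is the per-dimension median, whatever the shape of the distribution (the function below is illustrative, not from the PR):

```python
import torch

def quantize_1bit(x: torch.Tensor, median: torch.Tensor) -> torch.Tensor:
    # median: (dim,) per-dimension 50th percentile of the pre-training embeddings
    return (x >= median).to(torch.uint8)
```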
yawenzzzz left a comment
LGTM! Just a small question about int8 vs. uint8.
| - "quantiles": torch.Tensor of shape (dim, num_buckets+1) | ||
| - "midpoints": torch.Tensor of shape (dim, num_buckets) |
Suggestion: add a note saying that quantiles are for quantization and midpoints are for dequantization.
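A minimal sketch of what that note could describe, with shapes taken from the docstring above; the gather-based lookup is my assumption about how midpoints are used, not necessarily how the PR does it:

```python
import torch

def dequantize(codes: torch.Tensor, midpoints: torch.Tensor) -> torch.Tensor:
    # codes: (..., dim) integer bucket indices in [0, num_buckets)
    # midpoints: (dim, num_buckets) representative value of each bucket
    dim = codes.shape[-1]
    flat = codes.reshape(-1, dim).long()       # (N_total, dim)
    # Look up each value's bucket midpoint, per embedding dimension.
    recon = midpoints.gather(1, flat.t()).t()  # (N_total, dim)
    return recon.reshape(codes.shape)
```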
```python
# Flatten to (N_total, dim)
# Convert to uint8 first to handle int8 wrap-around (128-255 stored as -128 to -1)
flat = quantized.reshape(-1, dim).to(torch.uint8).long()
```
I'm wondering why we don't just quantize to uint8; with that, there's no need to convert to uint8 in the dequantization.
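A toy example of the wrap-around being worked around here (not the PR's code):

```python
import torch

codes = torch.tensor([0, 127, 128, 255])
as_int8 = codes.to(torch.uint8).view(torch.int8)  # bit reinterpretation: [0, 127, -128, -1]
restored = as_int8.view(torch.uint8).long()       # back to [0, 127, 128, 255]
assert torch.equal(restored, codes)
# Storing the codes as uint8 from the start would make this round-trip cast unnecessary.
```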

Try quantizing to 8/4/2/1-bit using the distribution of embeddings on the pre-training dataset.
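A rough sketch of how the per-dimension bucket edges could be computed from a sample of pre-training embeddings; the function name and the random stand-in sample are illustrative, not the PR's code:

```python
import torch

def compute_quantile_edges(sample: torch.Tensor, bits: int) -> torch.Tensor:
    # sample: (N, dim) embeddings drawn from the pre-training dataset
    num_buckets = 2 ** bits
    probs = torch.linspace(0.0, 1.0, num_buckets + 1)
    # One set of quantile edges per embedding dimension: (dim, num_buckets + 1)
    return torch.quantile(sample, probs, dim=0).t()

# Example: edges for 8-, 4-, 2-, and 1-bit quantization from a random stand-in sample.
sample = torch.randn(10_000, 768)
edges = {bits: compute_quantile_edges(sample, bits) for bits in (8, 4, 2, 1)}
```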