Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 100 additions & 24 deletions serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,41 @@ import (

const maxConfigSize = 1 << 20 // 1MB max config size

const (
// maxRGSetSize limits roaring bitmap deserialization.
// 16MB covers worst-case bitmaps for millions of row groups.
maxRGSetSize = 16 << 20

// maxNumPaths matches PathID's uint16 range.
maxNumPaths = 65535

// maxTermsPerPath caps string index terms per path.
// Default CardinalityThreshold is 10,000; 1M is generous headroom.
maxTermsPerPath = 1_000_000

// maxTrigramsPerPath caps trigram entries per path.
// ASCII ceiling is ~2M (128^3); 10M covers extreme Unicode FTS.
maxTrigramsPerPath = 10_000_000

// maxBloomWords caps bloom filter word count.
// Default is 65536 bits (1024 words). 1M words (~8MB) is generous.
maxBloomWords = 1 << 20

// maxHLLRegisters caps HyperLogLog register count.
// Max precision 16 needs 2^16 = 65536 registers.
maxHLLRegisters = 1 << 16
)

var (
// ErrVersionMismatch is returned by Decode when the binary format version
// does not match the expected version (Version constant).
ErrVersionMismatch = errors.New("version mismatch")

// ErrInvalidFormat is returned by Decode when the binary data is structurally
// invalid: unrecognized magic bytes, oversized allocations, or corrupt fields.
ErrInvalidFormat = errors.New("invalid format")
)

// CompressionLevel specifies the compression level for index serialization.
type CompressionLevel int

Expand Down Expand Up @@ -57,15 +92,21 @@ func writeRGSet(w io.Writer, rs *RGSet) error {
return err
}

func readRGSet(r io.Reader) (*RGSet, error) {
func readRGSet(r io.Reader, maxRGs uint32) (*RGSet, error) {
var numRGs uint32
if err := binary.Read(r, binary.LittleEndian, &numRGs); err != nil {
return nil, err
}
if numRGs > maxRGs {
return nil, errors.Wrapf(ErrInvalidFormat, "rgset numRGs %d exceeds max %d", numRGs, maxRGs)
}
var dataLen uint32
if err := binary.Read(r, binary.LittleEndian, &dataLen); err != nil {
return nil, err
}
if dataLen > maxRGSetSize {
return nil, errors.Wrapf(ErrInvalidFormat, "rgset data length %d exceeds max %d", dataLen, maxRGSetSize)
}
data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil {
return nil, err
Expand Down Expand Up @@ -158,7 +199,7 @@ func EncodeWithLevel(idx *GINIndex, level CompressionLevel) ([]byte, error) {

func Decode(data []byte) (*GINIndex, error) {
if len(data) < 4 {
return nil, errors.New("data too short")
return nil, errors.Wrap(ErrInvalidFormat, "data too short")
}

var decompressed []byte
Expand All @@ -179,17 +220,7 @@ func Decode(data []byte) (*GINIndex, error) {
return nil, errors.Wrap(err, "decompress data")
}
default:
// Legacy format: try zstd decompression without magic (backward compatibility)
decoder, err := zstd.NewReader(nil)
if err != nil {
return nil, errors.Wrap(err, "create zstd decoder")
}
defer decoder.Close()

decompressed, err = decoder.DecodeAll(data, nil)
if err != nil {
return nil, errors.Wrap(err, "decompress data")
}
return nil, errors.Wrapf(ErrInvalidFormat, "unrecognized magic bytes: %q", magic)
}

buf := bytes.NewReader(decompressed)
Expand All @@ -213,11 +244,11 @@ func Decode(data []byte) (*GINIndex, error) {
return nil, errors.Wrap(err, "read string indexes")
}

if err := readStringLengthIndexes(buf, idx); err != nil {
if err := readStringLengthIndexes(buf, idx, idx.Header.NumRowGroups); err != nil {
return nil, errors.Wrap(err, "read string length indexes")
}

if err := readNumericIndexes(buf, idx); err != nil {
if err := readNumericIndexes(buf, idx, idx.Header.NumRowGroups); err != nil {
return nil, errors.Wrap(err, "read numeric indexes")
}

Expand All @@ -234,7 +265,7 @@ func Decode(data []byte) (*GINIndex, error) {
}

if idx.Header.Flags&FlagHasDocIDMap != 0 {
mapping, err := readDocIDMapping(buf)
mapping, err := readDocIDMapping(buf, idx.Header.NumDocs)
if err != nil {
return nil, errors.Wrap(err, "read docid mapping")
}
Expand Down Expand Up @@ -277,11 +308,14 @@ func readHeader(r io.Reader, idx *GINIndex) error {
return err
}
if string(idx.Header.Magic[:]) != MagicBytes {
return errors.New("invalid magic bytes")
return errors.Wrapf(ErrInvalidFormat, "invalid inner magic bytes: %q", string(idx.Header.Magic[:]))
}
if err := binary.Read(r, binary.LittleEndian, &idx.Header.Version); err != nil {
return err
}
if idx.Header.Version != Version {
return errors.Wrapf(ErrVersionMismatch, "got version %d, expected %d", idx.Header.Version, Version)
}
if err := binary.Read(r, binary.LittleEndian, &idx.Header.Flags); err != nil {
return err
}
Expand Down Expand Up @@ -323,6 +357,9 @@ func writePathDirectory(w io.Writer, idx *GINIndex) error {
}

func readPathDirectory(r io.Reader, idx *GINIndex) error {
if idx.Header.NumPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "path count %d exceeds max %d", idx.Header.NumPaths, maxNumPaths)
}
for i := uint32(0); i < idx.Header.NumPaths; i++ {
var entry PathEntry
if err := binary.Read(r, binary.LittleEndian, &entry.PathID); err != nil {
Expand Down Expand Up @@ -383,6 +420,9 @@ func readBloomFilter(r io.Reader) (*BloomFilter, error) {
if err := binary.Read(r, binary.LittleEndian, &numWords); err != nil {
return nil, err
}
if numWords > maxBloomWords {
return nil, errors.Wrapf(ErrInvalidFormat, "bloom filter word count %d exceeds max %d", numWords, maxBloomWords)
}
bits := make([]uint64, numWords)
for i := uint32(0); i < numWords; i++ {
if err := binary.Read(r, binary.LittleEndian, &bits[i]); err != nil {
Expand Down Expand Up @@ -424,6 +464,9 @@ func readStringIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "string index path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
Expand All @@ -433,6 +476,9 @@ func readStringIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numTerms); err != nil {
return err
}
if numTerms > maxTermsPerPath {
return errors.Wrapf(ErrInvalidFormat, "terms count %d for path %d exceeds max %d", numTerms, pathID, maxTermsPerPath)
}
si := &StringIndex{
Terms: make([]string, numTerms),
RGBitmaps: make([]*RGSet, numTerms),
Expand All @@ -448,7 +494,7 @@ func readStringIndexes(r io.Reader, idx *GINIndex) error {
}
si.Terms[j] = string(termBytes)

rgSet, err := readRGSet(r)
rgSet, err := readRGSet(r, idx.Header.NumRowGroups)
if err != nil {
return err
}
Expand Down Expand Up @@ -495,11 +541,14 @@ func writeStringLengthIndexes(w io.Writer, idx *GINIndex) error {
return nil
}

func readStringLengthIndexes(r io.Reader, idx *GINIndex) error {
func readStringLengthIndexes(r io.Reader, idx *GINIndex, maxRGs uint32) error {
var numPaths uint32
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "string length index path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
Expand All @@ -516,6 +565,9 @@ func readStringLengthIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numRGs); err != nil {
return err
}
if numRGs > maxRGs {
return errors.Wrapf(ErrInvalidFormat, "string length index rg count %d for path %d exceeds max %d", numRGs, pathID, maxRGs)
}
sli.RGStats = make([]RGStringLengthStat, numRGs)
for j := uint32(0); j < numRGs; j++ {
if err := binary.Read(r, binary.LittleEndian, &sli.RGStats[j].Min); err != nil {
Expand Down Expand Up @@ -574,11 +626,14 @@ func writeNumericIndexes(w io.Writer, idx *GINIndex) error {
return nil
}

func readNumericIndexes(r io.Reader, idx *GINIndex) error {
func readNumericIndexes(r io.Reader, idx *GINIndex, maxRGs uint32) error {
var numPaths uint32
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "numeric index path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
Expand All @@ -602,6 +657,9 @@ func readNumericIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numRGs); err != nil {
return err
}
if numRGs > maxRGs {
return errors.Wrapf(ErrInvalidFormat, "numeric index rg count %d for path %d exceeds max %d", numRGs, pathID, maxRGs)
}
ni.RGStats = make([]RGNumericStat, numRGs)
for j := uint32(0); j < numRGs; j++ {
if err := binary.Read(r, binary.LittleEndian, &minBits); err != nil {
Expand Down Expand Up @@ -648,16 +706,19 @@ func readNullIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "null index path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
return err
}
nullBitmap, err := readRGSet(r)
nullBitmap, err := readRGSet(r, idx.Header.NumRowGroups)
if err != nil {
return err
}
presentBitmap, err := readRGSet(r)
presentBitmap, err := readRGSet(r, idx.Header.NumRowGroups)
if err != nil {
return err
}
Expand Down Expand Up @@ -716,6 +777,9 @@ func readTrigramIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "trigram index path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
Expand Down Expand Up @@ -749,6 +813,9 @@ func readTrigramIndexes(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numTrigrams); err != nil {
return err
}
if numTrigrams > maxTrigramsPerPath {
return errors.Wrapf(ErrInvalidFormat, "trigram count %d for path %d exceeds max %d", numTrigrams, pathID, maxTrigramsPerPath)
}
for j := uint32(0); j < numTrigrams; j++ {
var trigramLen uint8
if err := binary.Read(r, binary.LittleEndian, &trigramLen); err != nil {
Expand All @@ -758,7 +825,7 @@ func readTrigramIndexes(r io.Reader, idx *GINIndex) error {
if _, err := io.ReadFull(r, trigramBytes); err != nil {
return err
}
rgSet, err := readRGSet(r)
rgSet, err := readRGSet(r, idx.Header.NumRowGroups)
if err != nil {
return err
}
Expand Down Expand Up @@ -796,6 +863,9 @@ func readHyperLogLogs(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numPaths); err != nil {
return err
}
if numPaths > maxNumPaths {
return errors.Wrapf(ErrInvalidFormat, "hll path count %d exceeds max %d", numPaths, maxNumPaths)
}
for i := uint32(0); i < numPaths; i++ {
var pathID uint16
if err := binary.Read(r, binary.LittleEndian, &pathID); err != nil {
Expand All @@ -809,6 +879,9 @@ func readHyperLogLogs(r io.Reader, idx *GINIndex) error {
if err := binary.Read(r, binary.LittleEndian, &numRegisters); err != nil {
return err
}
if numRegisters > maxHLLRegisters {
return errors.Wrapf(ErrInvalidFormat, "hll register count %d exceeds max %d", numRegisters, maxHLLRegisters)
}
registers := make([]uint8, numRegisters)
if _, err := io.ReadFull(r, registers); err != nil {
return err
Expand All @@ -830,11 +903,14 @@ func writeDocIDMapping(w io.Writer, mapping []DocID) error {
return nil
}

func readDocIDMapping(r io.Reader) ([]DocID, error) {
func readDocIDMapping(r io.Reader, maxDocs uint64) ([]DocID, error) {
var numDocs uint64
if err := binary.Read(r, binary.LittleEndian, &numDocs); err != nil {
return nil, err
}
if numDocs > maxDocs {
return nil, errors.Wrapf(ErrInvalidFormat, "docid mapping count %d exceeds max %d", numDocs, maxDocs)
}
mapping := make([]DocID, numDocs)
for i := uint64(0); i < numDocs; i++ {
var docID uint64
Expand Down
Loading