Skip to content

Commit 0ad61e9

Browse files
committed
feat: string serde with validation for kll
1 parent 1b078b4 commit 0ad61e9

2 files changed

Lines changed: 230 additions & 6 deletions

File tree

kll/items_sketch.go

Lines changed: 56 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,12 @@ func NewKllItemsSketchFromSlice[C comparable](sl []byte, compareFn common.Compar
145145
if err != nil {
146146
return nil, err
147147
}
148+
if serdeWithValidation, ok := any(serde).(common.ItemSketchSerdeWithValidation[C]); ok {
149+
if err := serdeWithValidation.ValidateOne(deserItems[0]); err != nil {
150+
return nil, err
151+
}
152+
}
153+
148154
minItem = deserItems[0]
149155
maxItem = deserItems[0]
150156
hasMinMax = true
@@ -157,19 +163,38 @@ func NewKllItemsSketchFromSlice[C comparable](sl []byte, compareFn common.Compar
157163
return nil, err
158164
}
159165
minItem = deserMinItems[0]
166+
serdeWithValidation, isSerdeWithValidation := any(serde).(common.ItemSketchSerdeWithValidation[C])
167+
if isSerdeWithValidation {
168+
if err := serdeWithValidation.ValidateOne(minItem); err != nil {
169+
return nil, err
170+
}
171+
}
172+
160173
offset += serde.SizeOf(minItem)
161174
deserMaxItems, err := serde.DeserializeManyFromSlice(sl, offset, 1)
162175
if err != nil {
163176
return nil, err
164177
}
165178
maxItem = deserMaxItems[0]
179+
if isSerdeWithValidation {
180+
if err := serdeWithValidation.ValidateOne(maxItem); err != nil {
181+
return nil, err
182+
}
183+
}
184+
166185
hasMinMax = true
167186
offset += serde.SizeOf(maxItem)
168187
numRetained := levelsArr[memVal.numLevels] - levelsArr[0]
169188
deseRetItems, err := serde.DeserializeManyFromSlice(sl, offset, int(numRetained))
170189
if err != nil {
171190
return nil, err
172191
}
192+
if isSerdeWithValidation {
193+
if err := serdeWithValidation.ValidateMany(deseRetItems); err != nil {
194+
return nil, err
195+
}
196+
}
197+
173198
for i := uint32(0); i < numRetained; i++ {
174199
items[i+levelsArr[0]] = deseRetItems[i]
175200
}
@@ -539,8 +564,14 @@ func (s *ItemsSketch[C]) ToSlice() ([]byte, error) {
539564
numLevels := uint8(s.numLevels)
540565
//end of full preamble
541566
lvlsArr := s.getLevelsArray()
542-
minMaxByteArr := s.getMinMaxByteArr()
543-
itemsByteArr := s.getRetainedItemsByteArr()
567+
minMaxByteArr, err := s.getMinMaxByteArr()
568+
if err != nil {
569+
return nil, err
570+
}
571+
itemsByteArr, err := s.getRetainedItemsByteArr()
572+
if err != nil {
573+
return nil, err
574+
}
544575

545576
binary.LittleEndian.PutUint64(bytesOut[8:16], n)
546577
binary.LittleEndian.PutUint16(bytesOut[16:18], minK)
@@ -624,13 +655,22 @@ func (s *ItemsSketch[C]) getMinMaxSizeBytes() int {
624655
return s.serde.SizeOf(s.minItem) + s.serde.SizeOf(s.maxItem)
625656
}
626657

627-
func (s *ItemsSketch[C]) getMinMaxByteArr() []byte {
658+
func (s *ItemsSketch[C]) getMinMaxByteArr() ([]byte, error) {
659+
if serdeWithValidation, ok := any(s.serde).(common.ItemSketchSerdeWithValidation[C]); ok {
660+
if err := serdeWithValidation.ValidateOne(s.minItem); err != nil {
661+
return nil, err
662+
}
663+
if err := serdeWithValidation.ValidateOne(s.maxItem); err != nil {
664+
return nil, err
665+
}
666+
}
667+
628668
minBytes := s.serde.SerializeOneToSlice(s.minItem)
629669
maxBytes := s.serde.SerializeOneToSlice(s.maxItem)
630670
minMaxBytes := make([]byte, len(minBytes)+len(maxBytes))
631671
copy(minMaxBytes, minBytes)
632672
copy(minMaxBytes[len(minBytes):], maxBytes)
633-
return minMaxBytes
673+
return minMaxBytes, nil
634674
}
635675

636676
func (s *ItemsSketch[C]) getSingleItemSizeBytes() (int, error) {
@@ -646,6 +686,11 @@ func (s *ItemsSketch[C]) getSingleItemByteArr() ([]byte, error) {
646686
if err != nil {
647687
return nil, err
648688
}
689+
if serdeWithValidation, ok := any(s.serde).(common.ItemSketchSerdeWithValidation[C]); ok {
690+
if err := serdeWithValidation.ValidateOne(v); err != nil {
691+
return nil, err
692+
}
693+
}
649694
return s.serde.SerializeOneToSlice(v), nil
650695
}
651696

@@ -663,9 +708,14 @@ func (s *ItemsSketch[C]) getRetainedItemsArray() []C {
663708
return outArr
664709
}
665710

666-
func (s *ItemsSketch[C]) getRetainedItemsByteArr() []byte {
711+
func (s *ItemsSketch[C]) getRetainedItemsByteArr() ([]byte, error) {
667712
retArr := s.getRetainedItemsArray()
668-
return s.serde.SerializeManyToSlice(retArr)
713+
if serdeWithValidation, ok := any(s.serde).(common.ItemSketchSerdeWithValidation[C]); ok {
714+
if err := serdeWithValidation.ValidateMany(retArr); err != nil {
715+
return nil, err
716+
}
717+
}
718+
return s.serde.SerializeManyToSlice(retArr), nil
669719
}
670720

671721
func (s *ItemsSketch[C]) getRetainedItemsSizeBytes() int {

kll/items_sketch_test.go

Lines changed: 174 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1071,3 +1071,177 @@ func TestL0SortDuringMerge(t *testing.T) {
10711071
}
10721072
}
10731073
}
1074+
1075+
func TestToSliceValidateWithStringSerde(t *testing.T) {
1076+
comparator := common.ItemSketchStringComparator(false)
1077+
1078+
tests := []struct {
1079+
name string
1080+
validateUTF8 bool
1081+
items []string
1082+
expectErr bool
1083+
}{
1084+
{
1085+
name: "valid utf8 strings with validation enabled",
1086+
validateUTF8: true,
1087+
items: []string{"hello", "world", "안녕"},
1088+
expectErr: false,
1089+
},
1090+
{
1091+
name: "invalid utf8 string with validation enabled",
1092+
validateUTF8: true,
1093+
items: []string{"hello", string([]byte{0xff, 0xfe})},
1094+
expectErr: true,
1095+
},
1096+
{
1097+
name: "invalid utf8 string with validation disabled",
1098+
validateUTF8: false,
1099+
items: []string{"hello", string([]byte{0xff, 0xfe})},
1100+
expectErr: false,
1101+
},
1102+
{
1103+
name: "empty sketch with validation enabled",
1104+
validateUTF8: true,
1105+
items: []string{},
1106+
expectErr: false,
1107+
},
1108+
{
1109+
name: "single valid item with validation enabled",
1110+
validateUTF8: true,
1111+
items: []string{"hello"},
1112+
expectErr: false,
1113+
},
1114+
{
1115+
name: "single invalid utf8 item with validation enabled",
1116+
validateUTF8: true,
1117+
items: []string{string([]byte{0xff, 0xfe})},
1118+
expectErr: true,
1119+
},
1120+
{
1121+
name: "single invalid utf8 item with validation disabled",
1122+
validateUTF8: false,
1123+
items: []string{string([]byte{0xff, 0xfe})},
1124+
expectErr: false,
1125+
},
1126+
{
1127+
name: "invalid utf8 as min item with validation enabled",
1128+
validateUTF8: true,
1129+
items: []string{string([]byte{0x80}), "zzzz"},
1130+
expectErr: true,
1131+
},
1132+
{
1133+
name: "invalid utf8 as max item with validation enabled",
1134+
validateUTF8: true,
1135+
items: []string{"aaaa", string([]byte{0xff, 0xfe})},
1136+
expectErr: true,
1137+
},
1138+
}
1139+
for _, tt := range tests {
1140+
t.Run(tt.name, func(t *testing.T) {
1141+
serde := common.ItemSketchStringSerDe{ValidateUTF8: tt.validateUTF8}
1142+
sketch, err := NewKllItemsSketchWithDefault[string](comparator, serde)
1143+
assert.NoError(t, err)
1144+
1145+
for _, item := range tt.items {
1146+
sketch.Update(item)
1147+
}
1148+
1149+
_, err = sketch.ToSlice()
1150+
if tt.expectErr {
1151+
assert.ErrorIs(t, err, common.ErrInvalidUTF8)
1152+
} else {
1153+
assert.NoError(t, err)
1154+
}
1155+
})
1156+
}
1157+
}
1158+
1159+
func TestFromSliceValidateWithStringSerde(t *testing.T) {
1160+
comparator := common.ItemSketchStringComparator(false)
1161+
invalidMin := string([]byte{0x80})
1162+
invalidMax := string([]byte{0xff, 0xfe})
1163+
serdeNoValidation := common.ItemSketchStringSerDe{ValidateUTF8: false}
1164+
serdeWithValidation := common.ItemSketchStringSerDe{ValidateUTF8: true}
1165+
1166+
tests := []struct {
1167+
name string
1168+
items []string
1169+
serdeDeserialize common.ItemSketchStringSerDe
1170+
expectErr bool
1171+
}{
1172+
{
1173+
name: "valid utf8 deserialized with validation enabled",
1174+
items: []string{"hello", "world", "안녕"},
1175+
serdeDeserialize: serdeWithValidation,
1176+
expectErr: false,
1177+
},
1178+
{
1179+
name: "invalid utf8 in retained items deserialized with validation enabled",
1180+
items: []string{"hello", invalidMax},
1181+
serdeDeserialize: serdeWithValidation,
1182+
expectErr: true,
1183+
},
1184+
{
1185+
name: "invalid utf8 as min item deserialized with validation enabled",
1186+
items: []string{invalidMin, "zzzz"},
1187+
serdeDeserialize: serdeWithValidation,
1188+
expectErr: true,
1189+
},
1190+
{
1191+
name: "invalid utf8 as max item deserialized with validation enabled",
1192+
items: []string{"aaaa", invalidMax},
1193+
serdeDeserialize: serdeWithValidation,
1194+
expectErr: true,
1195+
},
1196+
{
1197+
name: "invalid utf8 deserialized with validation disabled",
1198+
items: []string{"hello", invalidMax},
1199+
serdeDeserialize: serdeNoValidation,
1200+
expectErr: false,
1201+
},
1202+
{
1203+
name: "empty sketch deserialized with validation enabled",
1204+
items: []string{},
1205+
serdeDeserialize: serdeWithValidation,
1206+
expectErr: false,
1207+
},
1208+
{
1209+
name: "single valid item deserialized with validation enabled",
1210+
items: []string{"hello"},
1211+
serdeDeserialize: serdeWithValidation,
1212+
expectErr: false,
1213+
},
1214+
{
1215+
name: "single invalid utf8 item deserialized with validation enabled",
1216+
items: []string{invalidMax},
1217+
serdeDeserialize: serdeWithValidation,
1218+
expectErr: true,
1219+
},
1220+
{
1221+
name: "single invalid utf8 item deserialized with validation disabled",
1222+
items: []string{invalidMax},
1223+
serdeDeserialize: serdeNoValidation,
1224+
expectErr: false,
1225+
},
1226+
}
1227+
for _, tt := range tests {
1228+
t.Run(tt.name, func(t *testing.T) {
1229+
sketch, err := NewKllItemsSketchWithDefault[string](comparator, serdeNoValidation)
1230+
assert.NoError(t, err)
1231+
1232+
for _, item := range tt.items {
1233+
sketch.Update(item)
1234+
}
1235+
1236+
slc, err := sketch.ToSlice()
1237+
assert.NoError(t, err)
1238+
1239+
_, err = NewKllItemsSketchFromSlice[string](slc, comparator, tt.serdeDeserialize)
1240+
if tt.expectErr {
1241+
assert.ErrorIs(t, err, common.ErrInvalidUTF8)
1242+
} else {
1243+
assert.NoError(t, err)
1244+
}
1245+
})
1246+
}
1247+
}

0 commit comments

Comments
 (0)