diff --git a/changelog/fernantho_ssz-ql-calculate-generalized-indices.md b/changelog/fernantho_ssz-ql-calculate-generalized-indices.md new file mode 100644 index 000000000000..737e26a617e9 --- /dev/null +++ b/changelog/fernantho_ssz-ql-calculate-generalized-indices.md @@ -0,0 +1,3 @@ +### Added + +- Added GeneralizedIndicesFromPath function to calculate the GIs for a given sszInfo object and a PathElement diff --git a/encoding/ssz/query/BUILD.bazel b/encoding/ssz/query/BUILD.bazel index 79b8ad84c1f1..dd688542e6e4 100644 --- a/encoding/ssz/query/BUILD.bazel +++ b/encoding/ssz/query/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "bitlist.go", "bitvector.go", "container.go", + "generalized_index.go", "list.go", "path.go", "query.go", @@ -24,6 +25,7 @@ go_library( go_test( name = "go_default_test", srcs = [ + "generalized_index_test.go", "path_test.go", "query_test.go", "tag_parser_test.go", diff --git a/encoding/ssz/query/generalized_index.go b/encoding/ssz/query/generalized_index.go new file mode 100644 index 000000000000..049380afe4b0 --- /dev/null +++ b/encoding/ssz/query/generalized_index.go @@ -0,0 +1,358 @@ +package query + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +const ( + bytesPerChunk = 32 + bitsPerChunk = 256 + listBaseIndex = 2 +) + +type Element struct { + length bool + name string + indices *[]uint64 +} + +// GetGeneralizedIndexFromPath calculates the generalized index for a given path. +// To calculate the generalized index, two inputs are needed: +// 1. The sszInfo of the root info, to be able to navigate the SSZ structure +// 2. The path to the field (e.g., "field_a.field_b[3].field_c") +// It walks the path step by step, updating the generalized index at each step. +func GetGeneralizedIndexFromPath(info *sszInfo, path []PathElement) (uint64, error) { + if info == nil { + return 0, errors.New("sszInfo is nil") + } + + // If path element list is empty, no generalized index can be computed. + if len(path) == 0 { + return 0, errors.New("cannot compute generalized index for an empty path") + } + + // Starting from the root generalized index + root := uint64(1) + currentInfo := info + + for _, pathElement := range path { + name := pathElement.Name + element, err := processPathElement(name) + if err != nil { + return 0, err + } + + // If we descend to a basic type, the path cannot continue further + if isBasicType(currentInfo.sszType) { + return 0, fmt.Errorf("cannot descend into basic type %s for path element %q", currentInfo.sszType, name) + } + + // Check that we are in a container to access fields + if currentInfo.sszType != Container { + return 0, fmt.Errorf("indexing requires a container field step first, got %s", currentInfo.sszType) + } + + // Check if path element contains an array index (e.g., field_name[5]) + var idx *uint64 + if element.indices != nil && len(*element.indices) > 0 { + // Note: shortcut, extend to multi-dimensional arrays later + idx = &(*element.indices)[0] + } + + // Retrieve the field position and SSZInfo for the field in the current container + fieldPos, fieldSsz, err := getContainerFieldByName(currentInfo, element.name) + if err != nil { + return 0, fmt.Errorf("container field %q not found: %w", element.name, err) + } + + // Get the chunk count for the current container + chunkCount, err := getChunkCount(currentInfo) + if err != nil { + return 0, fmt.Errorf("chunk count error: %w", err) + } + + // Update the generalized index to point to the specified field + root = updateRoot(root, 1, chunkCount, fieldPos) + currentInfo = fieldSsz + + // Check if a path element is a length field + if element.length { + // Length field is only valid for List and Bitlist types + if fieldSsz.sszType != List && fieldSsz.sszType != Bitlist { + return 0, fmt.Errorf("len() is only supported for List and Bitlist types, got %s", fieldSsz.sszType) + } + currentInfo = &sszInfo{sszType: UintN, fixedSize: 8} + root = updateRoot(root, 1, 2, 1) + continue + } + + if idx != nil { + switch fieldSsz.sszType { + case List: + li, err := fieldSsz.ListInfo() + if err != nil { + return 0, fmt.Errorf("list info error: %w", err) + } + elem, err := li.Element() + if err != nil { + return 0, fmt.Errorf("list element error: %w", err) + } + // Compute chunk position for the element + var chunkPos uint64 + if isBasicType(elem.sszType) { + start := *idx * itemLengthFromInfo(elem) + chunkPos = start / bytesPerChunk + } else { + chunkPos = *idx + } + innerChunkCount, err := getChunkCount(fieldSsz) + if err != nil { + return 0, fmt.Errorf("chunk count error: %w", err) + } + root = updateRoot(root, listBaseIndex, innerChunkCount, chunkPos) + currentInfo = elem + + case Vector: + vi, err := fieldSsz.VectorInfo() + if err != nil { + return 0, fmt.Errorf("vector info error: %w", err) + } + elem, err := vi.Element() + if err != nil { + return 0, fmt.Errorf("vector element error: %w", err) + } + var ( + offset uint64 + multiplier uint64 + ) + if isBasicType(elem.sszType) { + multiplier = nextPowerOfTwo(vi.Length()) + offset = *idx + } else { + innerChunkCount, err := getChunkCount(fieldSsz) + if err != nil { + return 0, fmt.Errorf("chunk count error: %w", err) + } + multiplier = nextPowerOfTwo(innerChunkCount) + offset = *idx + } + root = updateRoot(root, 1, multiplier, offset) + currentInfo = elem + + case Bitlist: + // Bits packed into 256-bit chunks; select the chunk containing the bit + chunkPos := *idx / bitsPerChunk + innerChunkCount, err := getChunkCount(fieldSsz) + if err != nil { + return 0, fmt.Errorf("chunk count error: %w", err) + } + root = updateRoot(root, listBaseIndex, innerChunkCount, chunkPos) + // Bits element is not further descendable; set to basic to guard further steps + currentInfo = &sszInfo{sszType: Boolean, fixedSize: 1} + + case Bitvector: + chunkPos := *idx / bitsPerChunk + innerChunkCount, err := getChunkCount(fieldSsz) + if err != nil { + return 0, fmt.Errorf("chunk count error: %w", err) + } + root = updateRoot(root, 1, innerChunkCount, chunkPos) + // Bits element is not further descendable; set to basic to guard further steps + currentInfo = &sszInfo{sszType: Boolean, fixedSize: 1} + + default: + return 0, fmt.Errorf("indexing not supported for type %s", fieldSsz.sszType) + } + continue + } + } + + return root, nil +} + +// updateRoot computes the new generalized index based on the current root, base index, chunk count, and offset +// base index is typically 1 for containers and 2 for lists +// root = root * base_index * pow2ceil(chunk_count(container)) + fieldPos +func updateRoot(root uint64, baseIndex uint64, chunkCount uint64, offset uint64) uint64 { + return root*baseIndex*nextPowerOfTwo(chunkCount) + offset +} + +// isBasicType checks if the SSZType is a basic type +func isBasicType(sszType SSZType) bool { + switch sszType { + case UintN, Byte, Boolean: + return true + default: + return false + } +} + +// getChunkCount returns the number of chunks for the given SSZInfo (equivalent to chunk_count in the spec) +func getChunkCount(info *sszInfo) (uint64, error) { + switch info.sszType { + case UintN, Byte, Boolean: + return 1, nil + case Container: + containerInfo, err := info.ContainerInfo() + if err != nil { + return 0, err + } + return uint64(len(containerInfo.order)), nil + case List: + listInfo, err := info.ListInfo() + if err != nil { + return 0, err + } + // For Lists with basic element types, multiple elements can be packed into 32-byte chunks + elementInfo, err := listInfo.Element() + if err != nil { + return 0, err + } + elemLength := itemLengthFromInfo(elementInfo) + return (listInfo.Limit()*uint64(elemLength) + 31) / bytesPerChunk, nil + case Vector: + vectorInfo, err := info.VectorInfo() + if err != nil { + return 0, err + } + // For Vectors with basic element types, multiple elements can be packed into 32-byte chunks + elementInfo, err := vectorInfo.Element() + if err != nil { + return 0, err + } + elemLength := itemLengthFromInfo(elementInfo) + return (vectorInfo.Length()*uint64(elemLength) + 31) / bytesPerChunk, nil + case Bitlist: + bitlistInfo, err := info.BitlistInfo() + if err != nil { + return 0, err + } + return (bitlistInfo.Limit() + 255) / bitsPerChunk, nil // Bits are packed into 256-bit chunks + case Bitvector: + vectorInfo, err := info.BitvectorInfo() + if err != nil { + return 0, err + } + return (vectorInfo.Length() + 255) / bitsPerChunk, nil // Bits are packed into 256-bit chunks + default: + return 0, errors.New("unsupported SSZ type for chunk count calculation") + } +} + +// getContainerFieldByName finds a container field by name. +func getContainerFieldByName(info *sszInfo, fieldName string) (uint64, *sszInfo, error) { + containerInfo, err := info.ContainerInfo() + if err != nil { + return 0, nil, err + } + + for index, name := range containerInfo.order { + if name == fieldName { + fieldInfo := containerInfo.fields[name] + if fieldInfo == nil || fieldInfo.sszInfo == nil { + return 0, nil, fmt.Errorf("field %q has no ssz info", name) + } + return uint64(index), fieldInfo.sszInfo, nil + } + } + + return 0, nil, fmt.Errorf("field %q not found", fieldName) +} + +// itemLengthFromInfo calculates the byte length of an SSZ item based on its type information. +// For basic SSZ types (uint8, uint16, uint32, uint64, bool, etc.), it returns the actual +// size of the type in bytes. For complex types (containers, lists, vectors), it returns +// bytesPerChunk which represents the standard SSZ chunk size (32 bytes) used for +// Merkle tree operations in the SSZ serialization format. +func itemLengthFromInfo(info *sszInfo) uint64 { + if isBasicType(info.sszType) { + return info.Size() + } + return bytesPerChunk +} + +// Helpers for input processing + +// processPathElement processes a path element string and returns an Element struct +func processPathElement(elementStr string) (Element, error) { + element := Element{} + + // Processing element string + processingField := elementStr + + re := regexp.MustCompile(`^\s*len\s*\(\s*([^)]+?)\s*\)\s*$`) + matches := re.FindStringSubmatch(processingField) + if len(matches) == 2 { + element.length = true + // Extract the inner expression between len( and ) and continue parsing on that + processingField = matches[1] + } + + // Default name is the full working string (may be updated below if it contains indices) + element.name = processingField + + if strings.Contains(processingField, "[") { + // Split into field and indices, e.g., "array[0][1]" -> name:"array", indices:{0,1} + element.name = extractFieldName(processingField) + indices, err := extractArrayIndices(processingField) + if err != nil { + return Element{}, err + } + element.indices = &indices + } + + return element, nil +} + +// extractFieldName extracts the field name from a path element name (removes array indices) +// For example: "field_name[5]" returns "field_name" +func extractFieldName(name string) string { + if idx := strings.Index(name, "["); idx != -1 { + return name[:idx] + } + return strings.ToLower(name) +} + +// extractArrayIndices returns every bracketed, non-negative index in the name, +// e.g. "array[0][1]" -> []uint64{0, 1}. Errors if none are found or if any index is invalid. +func extractArrayIndices(name string) ([]uint64, error) { + // Match all bracketed content, then we'll parse as unsigned to catch negatives explicitly + re := regexp.MustCompile(`\[\s*([^\]]+)\s*\]`) + matches := re.FindAllStringSubmatch(name, -1) + + if len(matches) == 0 { + return nil, errors.New("no array indices found") + } + + indices := make([]uint64, 0, len(matches)) + for _, m := range matches { + raw := strings.TrimSpace(m[1]) + // Forbid signs explicitly; we want a clear error similar to ParseUint's message + if strings.HasPrefix(raw, "-") { + return nil, fmt.Errorf("cannot process negative indices %q", raw) + } + idx, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid array index: %w", err) + } + indices = append(indices, idx) + } + return indices, nil +} + +// Copied from fastssz +// Modified to return uint64 +func nextPowerOfTwo(v uint64) uint64 { + v-- + v |= v >> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + v++ + return uint64(v) +} diff --git a/encoding/ssz/query/generalized_index_test.go b/encoding/ssz/query/generalized_index_test.go new file mode 100644 index 000000000000..fa36bf9e14b5 --- /dev/null +++ b/encoding/ssz/query/generalized_index_test.go @@ -0,0 +1,350 @@ +package query_test + +import ( + "strings" + "testing" + + "github.com/OffchainLabs/prysm/v6/encoding/ssz/query" + sszquerypb "github.com/OffchainLabs/prysm/v6/proto/ssz_query" + "github.com/OffchainLabs/prysm/v6/testing/require" +) + +func TestGetIndicesFromPath_FixedNestedContainer(t *testing.T) { + fixedNestedContainer := &sszquerypb.FixedNestedContainer{} + + info, err := query.AnalyzeObject(fixedNestedContainer) + require.NoError(t, err) + require.NotNil(t, info, "Expected non-nil SSZ info") + + testCases := []struct { + name string + path string + expectedIndex uint64 + expectError bool + errorMessage string + }{ + { + name: "Value1 field", + path: ".value1", + expectedIndex: 2, + expectError: false, + }, + { + name: "Value3 field", + path: ".value3", + expectError: true, + errorMessage: "field \"value3\" not found", + }, + { + name: "Basic field cannot descend", + path: "value1.value1", + expectedIndex: 0, + expectError: true, + errorMessage: "cannot descend into basic type", + }, + { + name: "Indexing without container step", + path: "value2.value2[0]", + expectedIndex: 0, + expectError: true, + errorMessage: "indexing requires a container field step first", + }, + { + name: "Value2 field", + path: "value2", + expectedIndex: 3, + expectError: false, + }, + { + name: "Value2 -> element[0]", + path: "value2[0]", + expectedIndex: 96, + expectError: false, + }, + { + name: "Value2 -> element[31]", + path: "value2[31]", + expectedIndex: 127, + expectError: false, + }, + { + name: "Empty path error", + path: "", + expectedIndex: 0, + expectError: true, + errorMessage: "empty path", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provingFields, err := query.ParsePath(tc.path) + require.NoError(t, err) + + actualIndex, err := query.GetGeneralizedIndexFromPath(info, provingFields) + + if tc.expectError { + require.NotNil(t, err) + if tc.errorMessage != "" { + if !strings.Contains(err.Error(), tc.errorMessage) { + t.Errorf("Expected error message to contain '%s', but got: %s", tc.errorMessage, err.Error()) + } + } + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, actualIndex, "Generalized index mismatch for path: %s", tc.path) + t.Logf("Path: %s -> Generalized Index: %v", tc.path, actualIndex) + } + }) + } +} + +func TestGetIndicesFromPath_VariableTestContainer(t *testing.T) { + testSpec := &sszquerypb.VariableTestContainer{} + info, err := query.AnalyzeObject(testSpec) + require.NoError(t, err) + require.NotNil(t, info, "Expected non-nil SSZ info") + + testCases := []struct { + name string + path string + expectedIndex uint64 + expectError bool + errorMessage string + }{ + { + name: "leading_field", + path: "leading_field", + expectedIndex: 8, + expectError: false, + }, + { + name: ".leading_field", + path: ".leading_field", + expectedIndex: 8, + expectError: false, + }, + { + name: "field_list_uint64", + path: "field_list_uint64", + expectedIndex: 9, + expectError: false, + }, + { + name: "len(field_list_uint64)", + path: "len(field_list_uint64)", + expectedIndex: 19, + expectError: false, + }, + { + name: "bitlist_field", + path: "bitlist_field", + expectedIndex: 13, + expectError: false, + }, + { + name: "bitlist_field[0]", + path: "bitlist_field[0]", + expectedIndex: 208, + expectError: false, + }, + { + name: "bitlist_field[1]", + path: "bitlist_field[1]", + expectedIndex: 208, + expectError: false, + }, + { + name: "bitlist_field[-1]", + path: "bitlist_field[-1]", + expectError: true, + errorMessage: "cannot process negative indices \"-1\"", + }, + { + name: "len(bitlist_field)", + path: "len(bitlist_field)", + expectedIndex: 27, + expectError: false, + }, + { + name: "len(trailing_field)", + path: "len(trailing_field)", + expectError: true, + errorMessage: "len() is only supported for List and Bitlist types, got Vector", + }, + { + name: "field_list_container[0]", + path: "field_list_container[0]", + expectedIndex: 2560, + expectError: false, + }, + { + name: "field_list_uint64[0]", + path: "field_list_uint64[0]", + expectedIndex: 9216, + expectError: false, + }, + { + name: "field_list_uint64[2047]", + path: "field_list_uint64[2047]", + expectedIndex: 9727, + expectError: false, + }, + { + name: "nested", + path: "nested", + expectedIndex: 12, + expectError: false, + }, + { + name: "nested.field_list_uint64[10]", + path: "nested.field_list_uint64[10]", + expectedIndex: 3138, + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provingFields, err := query.ParsePath(tc.path) + require.NoError(t, err) + + actualIndex, err := query.GetGeneralizedIndexFromPath(info, provingFields) + + if tc.expectError { + require.NotNil(t, err) + if tc.errorMessage != "" { + if !strings.Contains(err.Error(), tc.errorMessage) { + t.Errorf("Expected error message to contain '%s', but got: %s", tc.errorMessage, err.Error()) + } + } + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, actualIndex, "Generalized index mismatch for path: %s", tc.path) + t.Logf("Path: %s -> Generalized Index: %v", tc.path, actualIndex) + } + }) + } +} + +func TestGetIndicesFromPath_FixedTestContainer(t *testing.T) { + testSpec := &sszquerypb.FixedTestContainer{} + info, err := query.AnalyzeObject(testSpec) + require.NoError(t, err) + require.NotNil(t, info, "Expected non-nil SSZ info") + + testCases := []struct { + name string + path string + expectedIndex uint64 + expectError bool + errorMessage string + }{ + { + name: "field_uint32", + path: "field_uint32", + expectedIndex: 16, + expectError: false, + }, + { + name: ".field_uint64", + path: ".field_uint64", + expectedIndex: 17, + expectError: false, + }, + { + name: "field_bool", + path: "field_bool", + expectedIndex: 18, + expectError: false, + }, + { + name: "field_bytes32", + path: "field_bytes32", + expectedIndex: 19, + expectError: false, + }, + { + name: "nested", + path: "nested", + expectedIndex: 20, + expectError: false, + }, + { + name: "vector_field", + path: "vector_field", + expectedIndex: 21, + expectError: false, + }, + { + name: "two_dimension_bytes_field", + path: "two_dimension_bytes_field", + expectedIndex: 22, + expectError: false, + }, + { + name: "bitvector64_field", + path: "bitvector64_field", + expectedIndex: 23, + expectError: false, + }, + { + name: "bitvector512_field", + path: "bitvector512_field", + expectedIndex: 24, + expectError: false, + }, + { + name: "bitvector64_field[0]", + path: "bitvector64_field[0]", + expectedIndex: 23, + expectError: false, + }, + { + name: "bitvector64_field[63]", + path: "bitvector64_field[63]", + expectedIndex: 23, + expectError: false, + }, + { + name: "bitvector512_field[0]", + path: "bitvector512_field[0]", + expectedIndex: 48, + expectError: false, + }, + { + name: "bitvector512_field[511]", + path: "bitvector512_field[511]", + expectedIndex: 49, + expectError: false, + }, + { + name: "trailing_field", + path: "trailing_field", + expectedIndex: 25, + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provingFields, err := query.ParsePath(tc.path) + require.NoError(t, err) + + actualIndex, err := query.GetGeneralizedIndexFromPath(info, provingFields) + + if tc.expectError { + require.NotNil(t, err) + if tc.errorMessage != "" { + if !strings.Contains(err.Error(), tc.errorMessage) { + t.Errorf("Expected error message to contain '%s', but got: %s", tc.errorMessage, err.Error()) + } + } + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, actualIndex, "Generalized index mismatch for path: %s", tc.path) + t.Logf("Path: %s -> Generalized Index: %v", tc.path, actualIndex) + } + }) + } +}