diff --git a/nodebuilder/header/service.go b/nodebuilder/header/service.go index ddf71f06b..66f3bfadf 100644 --- a/nodebuilder/header/service.go +++ b/nodebuilder/header/service.go @@ -20,6 +20,11 @@ var tracer = otel.Tracer("header/service") // ErrHeightZero returned when the provided block height is equal to 0. var ErrHeightZero = errors.New("height is equal to 0") +// ErrRangeTooLarge returned when the requested header range exceeds MaxRangeRequestSize. +// Clients should paginate their requests into ranges of at most MaxRangeRequestSize headers. +var ErrRangeTooLarge = fmt.Errorf("header/service: requested range exceeds maximum of %d headers", + libhead.MaxRangeRequestSize) + // Service represents the header Service that can be started / stopped on a node. // Service's main function is to manage its sub-services. Service can contain several // sub-services, such as Exchange, ExchangeServer, Syncer, and so forth. @@ -71,6 +76,10 @@ func (s *Service) GetRangeByHeight( from *header.ExtendedHeader, to uint64, ) (_ []*header.ExtendedHeader, err error) { + if from == nil { + return nil, errors.New("header/service: 'from' header is nil") + } + ctx, span := tracer.Start(ctx, "header/get-range-by-height") defer func() { utils.SetStatusAndEnd(span, err) @@ -79,11 +88,14 @@ func (s *Service) GetRangeByHeight( } }() - // Enforce the same MaxRangeRequestSize limit as the P2P server. // The store fetches headers in range [from.Height()+1, to), so the count is to - from.Height() - 1. - if to > from.Height()+1+libhead.MaxRangeRequestSize { - return nil, fmt.Errorf("header/service: requested range exceeds MaxRangeRequestSize (%d)", - libhead.MaxRangeRequestSize) + switch { + case to <= from.Height()+1: + return nil, fmt.Errorf("header/service: invalid range: 'to' (%d) must be greater than 'from' height + 1 (%d)", + to, from.Height()+1) + // Enforce the same MaxRangeRequestSize limit as the P2P server. + case to > from.Height()+1+libhead.MaxRangeRequestSize: + return nil, ErrRangeTooLarge } log.Infow("getting header range by height", "from", from.Height(), "to", to) diff --git a/nodebuilder/header/service_test.go b/nodebuilder/header/service_test.go index 7d1dd16b2..0a148296e 100644 --- a/nodebuilder/header/service_test.go +++ b/nodebuilder/header/service_test.go @@ -40,7 +40,7 @@ func (d *errorSyncer[H]) SyncWait(context.Context) error { return fmt.Errorf("dummy error") } -func TestGetRangeByHeight_MaxRangeRequestSize(t *testing.T) { +func TestGetRangeByHeight_Validation(t *testing.T) { from := &header.ExtendedHeader{ RawHeader: header.RawHeader{Height: 100}, } @@ -51,7 +51,7 @@ func TestGetRangeByHeight_MaxRangeRequestSize(t *testing.T) { // request 65 headers: from.Height()+1=101 to 166 exclusive = 65 headers to := from.Height() + 1 + libhead.MaxRangeRequestSize + 1 _, err := serv.GetRangeByHeight(context.Background(), from, to) - assert.ErrorContains(t, err, "MaxRangeRequestSize") + assert.ErrorIs(t, err, ErrRangeTooLarge) }) t.Run("accepts range at MaxRangeRequestSize", func(t *testing.T) { @@ -62,4 +62,21 @@ func TestGetRangeByHeight_MaxRangeRequestSize(t *testing.T) { serv.GetRangeByHeight(context.Background(), from, to) //nolint:errcheck }) }) + + t.Run("rejects nil 'from' header", func(t *testing.T) { + assert.NotPanics(t, func() { + _, err := serv.GetRangeByHeight(context.Background(), nil, 100) + assert.ErrorContains(t, err, "'from' header is nil") + }) + }) + + t.Run("rejects reversed range", func(t *testing.T) { + _, err := serv.GetRangeByHeight(context.Background(), from, from.Height()-1) + assert.ErrorContains(t, err, "invalid range") + }) + + t.Run("rejects empty range", func(t *testing.T) { + _, err := serv.GetRangeByHeight(context.Background(), from, from.Height()+1) + assert.ErrorContains(t, err, "invalid range") + }) }