diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5fbf9fc..a63c6c5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,10 +32,8 @@ jobs: run: make test - name: Coverage run: make test-cov - - name: Generate coverage report in XML format - run: make test-xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml diff --git a/Makefile b/Makefile index 77f04bc..a7a39ee 100644 --- a/Makefile +++ b/Makefile @@ -15,8 +15,7 @@ vet: @$(call run-go-mod-dir,go vet ./...,"go vet") lint: golangci-lint - $(GOLANGCI_LINT) run --timeout=10m -v - + $(GOLANGCI_LINT) run --timeout=10m -v --fix .PHONY: tidy tidy: @@ -32,21 +31,11 @@ generate: mockery protoc test: @$(call run-go-mod-dir,go test -race -covermode=atomic -coverprofile=coverage.out ./...,"go test") -test-cov: gocov - @$(call run-go-mod-dir-exclude,$(GOCOV) convert coverage.out > coverage.json,$(EXCLUDE_GO_MOD_DIRS),"gocov convert") - @$(call run-go-mod-dir-exclude,$(GOCOV) convert coverage.out | $(GOCOV) report,$(EXCLUDE_GO_MOD_DIRS),"gocov report") - -test-xml: test-cov gocov-xml - @jq -n '{ Packages: [ inputs.Packages ] | add }' $(shell find . -type f -name 'coverage.json' | sort) | $(GOCOVXML) > coverage.xml - -.PHONY: test-html - -test-html: test-cov gocov-html - @jq -n '{ Packages: [ inputs.Packages ] | add }' $(shell find . -type f -name 'coverage.json' | sort) | $(GOCOVHTML) -t kit -r > coverage.html - @open coverage.html +test-cov: + @$(call run-go-mod-dir-exclude,go tool cover -func=coverage.out,$(EXCLUDE_GO_MOD_DIRS),"go tool cover") .PHONY: check -check: fmt vet lint +check: tidy fmt vet lint @git diff --quiet || test $$(git diff --name-only | grep -v -e 'go.mod$$' -e 'go.sum$$' | wc -l) -eq 0 || ( echo "The following changes (result of code generators and code checks) have been detected:" && git --no-pager diff && false ) # fail if Git working tree is dirty # ========= Helpers =========== @@ -55,18 +44,6 @@ GOLANGCI_LINT = $(BIN_DIR)/golangci-lint golangci-lint: $(call go-get-tool,$(GOLANGCI_LINT),github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest) -GOCOV = $(BIN_DIR)/gocov -gocov: - $(call go-get-tool,$(GOCOV),github.com/axw/gocov/gocov@v1.0.0) - -GOCOVXML = $(BIN_DIR)/gocov-xml -gocov-xml: - $(call go-get-tool,$(GOCOVXML),github.com/AlekSi/gocov-xml@v1.0.0) - -GOCOVHTML = $(BIN_DIR)/gocov-html -gocov-html: - $(call go-get-tool,$(GOCOVHTML),github.com/matm/gocov-html/cmd/gocov-html@v1.4.0) - MOCKERY = $(BIN_DIR)/mockery mockery: $(call go-get-tool,$(MOCKERY),github.com/vektra/mockery/v2@v2.43.0) diff --git a/xapi/doc.go b/xapi/doc.go new file mode 100644 index 0000000..9f0a820 --- /dev/null +++ b/xapi/doc.go @@ -0,0 +1,53 @@ +// Package xapi provides a type-safe lightweight HTTP API framework for Go. +// +// Most HTTP handlers follow the same pattern - decode JSON, extract headers/params, +// validate, call business logic, encode response. xapi codifies that pattern using +// generics, so you write less but get more type safety. Your request and response +// types define the API contract. The optional interfaces provide flexibility when needed. +// +// The result: handlers that are mostly business logic, with HTTP operations abstracted +// away into a lightweight framework. You can use it with your existing HTTP router and +// server, keeping all existing middlewares and error handling. +// +// # Core Types +// +// [Endpoint] is the main type that wraps your [EndpointHandler] and applies middleware +// and error handling. Create endpoints using [NewEndpoint] with your handler and optional +// configuration via [EndpointOption] values. +// +// [EndpointFunc] is a function type that implements [EndpointHandler], providing a +// convenient way to create handlers from functions. +// +// # Optional Interfaces +// +// xapi defines four optional interfaces. Implement them on request and response types +// only when needed: +// +// [Validator] runs after JSON decoding to validate the request. You can use any validation +// library here. +// +// [Extracter] pulls data from the HTTP request that isn't in the JSON body, such as headers, +// route path params, or query strings. +// +// [StatusSetter] controls the HTTP status code. The default is 200, but you can override it +// to return 201 for creation, 204 for no content, etc. +// +// [RawWriter] bypasses JSON encoding entirely for HTML, or binary responses. +// Use this when you need full control over the response format. +// +// # Middleware +// +// Middleware works exactly like standard http.Handler middleware. Any middleware you're +// already using will work. Stack them in the order you need using [WithMiddleware]. They +// wrap the endpoint cleanly, keeping auth, logging, and metrics separate from your +// business logic. Use [MiddlewareFunc] to convert functions to middleware, or implement +// [MiddlewareHandler] for custom middleware types. +// +// # Error Handling +// +// Default behavior is a 500 status with the error text. Customize this using +// [WithErrorHandler] to distinguish validation errors from auth failures, map them to +// appropriate status codes, and format them consistently. Implement the [ErrorHandler] +// interface or use [ErrorFunc] for simple function-based handlers. The default error +// handling is provided by [DefaultErrorHandler]. +package xapi diff --git a/xapi/endpoint.go b/xapi/endpoint.go new file mode 100644 index 0000000..82ebba1 --- /dev/null +++ b/xapi/endpoint.go @@ -0,0 +1,132 @@ +package xapi + +import ( + "context" + "encoding/json" + "net/http" +) + +// EndpointHandler defines the interface for handling endpoint requests. +type EndpointHandler[TReq, TRes any] interface { + Handle(ctx context.Context, req *TReq) (*TRes, error) +} + +// EndpointFunc is a function type that implements EndpointHandler. +type EndpointFunc[TReq, TRes any] func(ctx context.Context, req *TReq) (*TRes, error) + +// Handle implements the EndpointHandler interface. +func (e EndpointFunc[TReq, TRes]) Handle(ctx context.Context, req *TReq) (*TRes, error) { + return e(ctx, req) +} + +// Extracter allows extracting additional data from the HTTP request, +// such as headers, query params, etc. +type Extracter interface { + Extract(r *http.Request) error +} + +// Validator allows validating endpoint requests. +type Validator interface { + Validate() error +} + +// StatusSetter allows setting a custom HTTP status code for the response. +type StatusSetter interface { + StatusCode() int +} + +// RawWriter allows writing raw data to the HTTP response instead of +// the default JSON encoder. +type RawWriter interface { + Write(w http.ResponseWriter) error +} + +// Endpoint represents a type-safe HTTP endpoint with middleware and error handling. +type Endpoint[TReq, TRes any] struct { + handler EndpointHandler[TReq, TRes] + opts *options +} + +// NewEndpoint creates a new Endpoint with the given handler and options. +func NewEndpoint[TReq, TRes any](handler EndpointHandler[TReq, TRes], opts ...EndpointOption) *Endpoint[TReq, TRes] { + e := &Endpoint[TReq, TRes]{ + handler: handler, + opts: &options{ + middleware: MiddlewareStack{}, + errorHandler: ErrorFunc(DefaultErrorHandler), + }, + } + + for _, option := range opts { + option.apply(e.opts) + } + + return e +} + +// Handler returns an http.Handler that processes requests for this endpoint. +func (e *Endpoint[TReq, TRes]) Handler() http.Handler { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req TReq + + if r.Body != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + + defer r.Body.Close() + } + + if extracter, ok := any(&req).(Extracter); ok { + if err := extracter.Extract(r); err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + } + + if validator, ok := any(&req).(Validator); ok { + if err := validator.Validate(); err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + } + + res, err := e.handler.Handle(r.Context(), &req) + if err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + + if rawWriter, ok := any(res).(RawWriter); ok { + if err := rawWriter.Write(w); err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + + return + } + + statusCode := http.StatusOK + + if statusSetter, ok := any(res).(StatusSetter); ok { + statusCode = statusSetter.StatusCode() + } + + resBody, err := json.Marshal(res) + if err != nil { + e.opts.errorHandler.HandleError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(statusCode) + w.Write(resBody) + }) + + if len(e.opts.middleware) > 0 { + return e.opts.middleware.Middleware(h) + } + + return h +} diff --git a/xapi/endpoint_test.go b/xapi/endpoint_test.go new file mode 100644 index 0000000..c44a7f4 --- /dev/null +++ b/xapi/endpoint_test.go @@ -0,0 +1,528 @@ +package xapi + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEndpoint_Handler(t *testing.T) { + t.Parallel() + + t.Run("BasicEndpoint", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return &BasicResponse{ + Message: "Hello " + req.Name, + ID: 123, + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "World"}))) + req.Header.Set("Content-Type", "application/json") + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "application/json; charset=utf-8", rec.Header().Get("Content-Type")) + assert.JSONEq(t, `{ + "message": "Hello World", + "id": 123 + }`, rec.Body.String()) + }) + + t.Run("WithValidation", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[ValidatedRequest, BasicResponse]( + func(ctx context.Context, req *ValidatedRequest) (*BasicResponse, error) { + return &BasicResponse{ + Message: "Valid request", + ID: 456, + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &ValidatedRequest{Name: ""}))) // Empty name should fail validation + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "name is required") + }) + + t.Run("WithExtraction", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[ExtractedRequest, BasicResponse]( + func(ctx context.Context, req *ExtractedRequest) (*BasicResponse, error) { + return &BasicResponse{ + Message: fmt.Sprintf("Hello %s from %s", req.Name, req.Language), + ID: 789, + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &ExtractedRequest{Name: "World"}))) + req.Header.Set("Language", "en-US") + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "application/json; charset=utf-8", rec.Header().Get("Content-Type")) + assert.JSONEq(t, `{ + "message": "Hello World from en-US", + "id": 789 + }`, rec.Body.String()) + }) + + t.Run("WithCustomStatusCode", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, StatusResponse]( + func(ctx context.Context, req *BasicRequest) (*StatusResponse, error) { + return &StatusResponse{ + Message: "Created successfully", + ID: 999, + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusCreated, rec.Result().StatusCode) + assert.Equal(t, "application/json; charset=utf-8", rec.Header().Get("Content-Type")) + assert.JSONEq(t, `{ + "message": "Created successfully", + "id": 999 + }`, rec.Body.String()) + }) + + t.Run("WithRawWriter", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, RawResponse]( + func(ctx context.Context, req *BasicRequest) (*RawResponse, error) { + return &RawResponse{ + Content: fmt.Sprintf("

Hello %s

", req.Name), + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "World"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "text/html", rec.Header().Get("Content-Type")) + assert.Equal(t, "

Hello World

", rec.Body.String()) + }) + + t.Run("HandlerError", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return nil, errors.New("handler error") + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "handler error") + }) + + t.Run("JSONDecodeError", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return &BasicResponse{Message: "Success"}, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest( + http.MethodPost, "/test", + strings.NewReader("invalid json"), + ) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "invalid character") + }) + + t.Run("JSONEncodeError", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, InvalidJSONResponse]( + func(ctx context.Context, req *BasicRequest) (*InvalidJSONResponse, error) { + return &InvalidJSONResponse{ + Channel: make(chan int), // Channels can't be JSON encoded + }, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + // With the improved implementation, JSON encoding errors are caught before headers are written + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "json: unsupported type") + }) + + t.Run("ExtractionError", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[ExtractionErrorRequest, BasicResponse]( + func(ctx context.Context, req *ExtractionErrorRequest) (*BasicResponse, error) { + return &BasicResponse{Message: "Success"}, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &ExtractionErrorRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "extraction error") + }) + + t.Run("RawWriterError", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, RawWriterErrorResponse]( + func(ctx context.Context, req *BasicRequest) (*RawWriterErrorResponse, error) { + return &RawWriterErrorResponse{}, nil + }, + ) + + endpoint := NewEndpoint(handler) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "raw writer error") + }) +} + +func TestEndpoint_WithMiddleware(t *testing.T) { + t.Parallel() + + t.Run("SingleMiddleware", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return &BasicResponse{Message: "Success"}, nil + }, + ) + + middleware := MiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware", "applied") + next.ServeHTTP(w, r) + }) + }) + + endpoint := NewEndpoint(handler, WithMiddleware(middleware)) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "applied", rec.Header().Get("X-Middleware")) + }) + + t.Run("MultipleMiddleware", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return &BasicResponse{Message: "Success"}, nil + }, + ) + + middleware1 := MiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware-1", "first") + next.ServeHTTP(w, r) + }) + }) + + middleware2 := MiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware-2", "second") + next.ServeHTTP(w, r) + }) + }) + + endpoint := NewEndpoint(handler, WithMiddleware(middleware1, middleware2)) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "first", rec.Header().Get("X-Middleware-1")) + assert.Equal(t, "second", rec.Header().Get("X-Middleware-2")) + }) + + t.Run("MiddlewareOrder", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return &BasicResponse{Message: "Success"}, nil + }, + ) + + order := []string{} + + // Middleware should be applied in reverse order (last added first) + middleware1 := MiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + order = append(order, "1") + }) + }) + + middleware2 := MiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + order = append(order, "2") + }) + }) + + endpoint := NewEndpoint(handler, WithMiddleware(middleware1, middleware2)) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.EqualValues(t, []string{"2", "1"}, order) + }) +} + +func TestEndpoint_WithCustomErrorHandler(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + return nil, errors.New("custom error") + }, + ) + + customErrorHandler := ErrorFunc(func(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{ + "error": "Custom: " + err.Error(), + }) + }) + + endpoint := NewEndpoint(handler, WithErrorHandler(customErrorHandler)) + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Result().StatusCode) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var response map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "Custom: custom error", response["error"]) +} + +func TestEndpoint_ContextCancellation(t *testing.T) { + t.Parallel() + + t.Run("ContextCancelled", func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + + handler := EndpointFunc[BasicRequest, BasicResponse]( + func(ctx context.Context, req *BasicRequest) (*BasicResponse, error) { + // Simulate context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + return &BasicResponse{Message: "Success"}, nil + } + }, + ) + + endpoint := NewEndpoint(handler) + + // Create a request with a cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + req := httptest.NewRequest(http.MethodPost, "/test", + bytes.NewBuffer(mustMarshalJSON(t, &BasicRequest{Name: "Test"}))) + req = req.WithContext(ctx) + + endpoint.Handler().ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Result().StatusCode) + assert.Contains(t, rec.Body.String(), "context canceled") + }) +} + +// Test types and implementations + +type BasicRequest struct { + Name string `json:"name"` +} + +type BasicResponse struct { + Message string `json:"message"` + ID int `json:"id"` +} + +type ValidatedRequest struct { + Name string `json:"name"` +} + +func (r *ValidatedRequest) Validate() error { + if r.Name == "" { + return errors.New("name is required") + } + return nil +} + +type ExtractedRequest struct { + Name string `json:"name"` + Language string `json:"-"` +} + +func (r *ExtractedRequest) Extract(req *http.Request) error { + r.Language = req.Header.Get("Language") + return nil +} + +type StatusResponse struct { + Message string `json:"message"` + ID int `json:"id"` +} + +func (r *StatusResponse) StatusCode() int { + return http.StatusCreated +} + +type RawResponse struct { + Content string `json:"-"` +} + +func (r *RawResponse) Write(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(r.Content)) + return err +} + +type InvalidJSONResponse struct { + Channel chan int `json:"channel"` +} + +func (r *InvalidJSONResponse) MarshalJSON() ([]byte, error) { + // Force an error during JSON marshaling + return nil, errors.New("json: unsupported type") +} + +type ExtractionErrorRequest struct { + Name string `json:"name"` +} + +func (r *ExtractionErrorRequest) Extract(req *http.Request) error { + return errors.New("extraction error") +} + +type RawWriterErrorResponse struct{} + +func (r *RawWriterErrorResponse) Write(w http.ResponseWriter) error { + return errors.New("raw writer error") +} + +// Helper functions + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + + return data +} diff --git a/xapi/error.go b/xapi/error.go new file mode 100644 index 0000000..075610a --- /dev/null +++ b/xapi/error.go @@ -0,0 +1,38 @@ +package xapi + +import ( + "encoding/json" + "errors" + "net/http" +) + +// ErrorHandler defines the interface for handling errors in HTTP responses. +type ErrorHandler interface { + HandleError(w http.ResponseWriter, err error) +} + +// ErrorFunc is a function type that implements ErrorHandler. +type ErrorFunc func(w http.ResponseWriter, err error) + +// HandleError implements the ErrorHandler interface. +func (e ErrorFunc) HandleError(w http.ResponseWriter, err error) { + e(w, err) +} + +// DefaultErrorHandler provides default error handling for common JSON errors. +func DefaultErrorHandler(w http.ResponseWriter, err error) { + var syntaxError *json.SyntaxError + var unmarshalTypeError *json.UnmarshalTypeError + var invalidUnmarshalError *json.InvalidUnmarshalError + + switch { + case errors.As(err, &syntaxError): + http.Error(w, syntaxError.Error(), http.StatusBadRequest) + case errors.As(err, &unmarshalTypeError): + http.Error(w, unmarshalTypeError.Error(), http.StatusBadRequest) + case errors.As(err, &invalidUnmarshalError): + http.Error(w, invalidUnmarshalError.Error(), http.StatusBadRequest) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/xapi/example_test.go b/xapi/example_test.go new file mode 100644 index 0000000..702c17d --- /dev/null +++ b/xapi/example_test.go @@ -0,0 +1,213 @@ +package xapi_test + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + + "github.com/gojekfarm/xtools/xapi" +) + +type CreateUserRequest struct { + Name string `json:"name"` + Email string `json:"email"` + Language string `json:"-"` +} + +func (user *CreateUserRequest) Validate() error { + if user.Name == "" { + return fmt.Errorf("name is required") + } + if user.Email == "" { + return fmt.Errorf("email is required") + } + return nil +} + +func (user *CreateUserRequest) Extract(r *http.Request) error { + user.Language = r.Header.Get("Language") + return nil +} + +type CreateUserResponse struct { + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + Language string `json:"language"` +} + +func (user *CreateUserResponse) StatusCode() int { + return http.StatusCreated +} + +func ExampleEndpoint_basic() { + createUser := xapi.EndpointFunc[CreateUserRequest, CreateUserResponse]( + func(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) { + // Simulate user creation logic + return &CreateUserResponse{ + ID: 1, + Name: req.Name, + Email: req.Email, + Language: req.Language, + }, nil + }, + ) + + endpoint := xapi.NewEndpoint(createUser) + + http.Handle("/users", endpoint.Handler()) + log.Println("Server starting on :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func ExampleEndpoint_customErrorHandler() { + type GetUserRequest struct { + ID int `json:"id"` + } + + type GetUserResponse struct { + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + } + + customErrorHandler := xapi.ErrorFunc(func(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + + errorResponse := map[string]string{ + "error": err.Error(), + } + + json.NewEncoder(w).Encode(errorResponse) + }) + + getUser := xapi.EndpointFunc[GetUserRequest, GetUserResponse]( + func(ctx context.Context, req *GetUserRequest) (*GetUserResponse, error) { + if req.ID <= 0 { + return nil, fmt.Errorf("invalid user ID: %d", req.ID) + } + + // Simulate user lookup + return &GetUserResponse{ + ID: req.ID, + Name: "John Doe", + Email: "john@example.com", + }, nil + }, + ) + + // Create endpoint with custom error handler + endpoint := xapi.NewEndpoint( + getUser, + xapi.WithErrorHandler(customErrorHandler), + ) + + http.Handle("/users/", endpoint.Handler()) +} + +func ExampleEndpoint_withMiddleware() { + type GetDataRequest struct { + ID string `json:"id"` + } + + type GetDataResponse struct { + ID string `json:"id"` + Data string `json:"data"` + } + + // Authentication middleware + authMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Simulate token validation + if token != "Bearer valid-token" { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } + + // Rate limiting middleware + rateLimitMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate rate limiting logic + w.Header().Set("X-RateLimit-Limit", "100") + w.Header().Set("X-RateLimit-Remaining", "99") + + next.ServeHTTP(w, r) + }) + } + + getData := xapi.EndpointFunc[GetDataRequest, GetDataResponse]( + func(ctx context.Context, req *GetDataRequest) (*GetDataResponse, error) { + requestID := ctx.Value("requestID") + log.Printf("Processing request %s for ID: %s", requestID, req.ID) + + return &GetDataResponse{ + ID: req.ID, + Data: fmt.Sprintf("Data for %s", req.ID), + }, nil + }, + ) + + endpoint := xapi.NewEndpoint( + getData, + xapi.WithMiddleware( + xapi.MiddlewareFunc(rateLimitMiddleware), + xapi.MiddlewareFunc(authMiddleware), + ), + ) + + http.Handle("/data", endpoint.Handler()) +} + +type GetArticleRequest struct { + ID string `json:"-"` +} + +func (article *GetArticleRequest) Extract(r *http.Request) error { + article.ID = r.PathValue("id") + + return nil +} + +type GetArticleResponse struct { + ID string `json:"id"` + Title string `json:"title"` +} + +func (article *GetArticleResponse) Write(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/html") + w.WriteHeader(http.StatusOK) + + _, _ = fmt.Fprintf(w, "

%s

", article.Title) + + return nil +} + +func ExampleEndpoint_withCustomResponseWriter() { + getArticle := xapi.EndpointFunc[GetArticleRequest, GetArticleResponse]( + func(ctx context.Context, req *GetArticleRequest) (*GetArticleResponse, error) { + return &GetArticleResponse{ + ID: req.ID, + Title: "Article " + req.ID, + }, nil + }, + ) + + endpoint := xapi.NewEndpoint(getArticle) + + http.Handle("/articles/{id}", endpoint.Handler()) + log.Println("Server starting on :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/xapi/go.mod b/xapi/go.mod new file mode 100644 index 0000000..fa94723 --- /dev/null +++ b/xapi/go.mod @@ -0,0 +1,11 @@ +module github.com/gojekfarm/xtools/xapi + +go 1.25 + +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/xapi/go.sum b/xapi/go.sum new file mode 100644 index 0000000..c4c1710 --- /dev/null +++ b/xapi/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/xapi/middleware.go b/xapi/middleware.go new file mode 100644 index 0000000..cd624e3 --- /dev/null +++ b/xapi/middleware.go @@ -0,0 +1,29 @@ +package xapi + +import "net/http" + +// MiddlewareHandler defines the interface for HTTP middleware. +type MiddlewareHandler interface { + Middleware(next http.Handler) http.Handler +} + +// MiddlewareFunc is a function type that implements MiddlewareHandler. +type MiddlewareFunc func(next http.Handler) http.Handler + +// Middleware implements the MiddlewareHandler interface. +func (m MiddlewareFunc) Middleware(next http.Handler) http.Handler { + return m(next) +} + +// MiddlewareStack represents a stack of middleware handlers. +type MiddlewareStack []MiddlewareHandler + +// Middleware applies all middleware in the stack to the given handler. +// Middleware is applied in reverse order, so the last added middleware +// wraps the innermost handler. +func (m MiddlewareStack) Middleware(next http.Handler) http.Handler { + for i := len(m) - 1; i >= 0; i-- { + next = m[i].Middleware(next) + } + return next +} diff --git a/xapi/options.go b/xapi/options.go new file mode 100644 index 0000000..0cad377 --- /dev/null +++ b/xapi/options.go @@ -0,0 +1,33 @@ +package xapi + +type options struct { + middleware MiddlewareStack + errorHandler ErrorHandler +} + +// EndpointOption defines the interface for endpoint configuration options. +type EndpointOption interface { + apply(o *options) +} + +// endpointOptionFunc is a function type that implements EndpointOption. +type endpointOptionFunc func(o *options) + +// apply implements the EndpointOption interface. +func (f endpointOptionFunc) apply(o *options) { + f(o) +} + +// WithMiddleware returns an EndpointOption that adds middleware to the endpoint. +func WithMiddleware(middlewares ...MiddlewareHandler) EndpointOption { + return endpointOptionFunc(func(o *options) { + o.middleware = append(o.middleware, middlewares...) + }) +} + +// WithErrorHandler returns an EndpointOption that sets a custom error handler for the endpoint. +func WithErrorHandler(errorHandler ErrorHandler) EndpointOption { + return endpointOptionFunc(func(o *options) { + o.errorHandler = errorHandler + }) +}