diff --git a/examples/using-sse/configs/.env b/examples/using-sse/configs/.env new file mode 100644 index 0000000000..52907af83b --- /dev/null +++ b/examples/using-sse/configs/.env @@ -0,0 +1,2 @@ +APP_NAME=using-sse +HTTP_PORT=9000 diff --git a/examples/using-sse/main.go b/examples/using-sse/main.go new file mode 100644 index 0000000000..af1c81f709 --- /dev/null +++ b/examples/using-sse/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "fmt" + "time" + + "gofr.dev/pkg/gofr" +) + +func main() { + app := gofr.New() + + // Stream the current time every second. + // c.Context.Done() fires on both client disconnect and server shutdown. + app.GET("/events", func(c *gofr.Context) (any, error) { + return gofr.SSEResponse(func(stream *gofr.SSEStream) error { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + i := 0 + + for { + select { + case <-c.Context.Done(): + return nil + case t := <-ticker.C: + if err := stream.Send(gofr.SSEEvent{ + ID: fmt.Sprintf("%d", i), + Name: "time", + Data: map[string]string{"time": t.Format(time.RFC3339)}, + }); err != nil { + return err + } + + i++ + } + } + }), nil + }) + + // A countdown that sends 11 events and closes. + app.GET("/countdown", func(c *gofr.Context) (any, error) { + return gofr.SSEResponse(func(stream *gofr.SSEStream) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for i := 10; i >= 0; i-- { + select { + case <-c.Context.Done(): + return nil + case <-ticker.C: + if err := stream.SendEvent("countdown", map[string]int{"remaining": i}); err != nil { + return err + } + } + } + + return stream.SendEvent("done", "Countdown complete!") + }), nil + }) + + app.Run() +} diff --git a/pkg/gofr/handler.go b/pkg/gofr/handler.go index e787d54c09..893ad6149c 100644 --- a/pkg/gofr/handler.go +++ b/pkg/gofr/handler.go @@ -53,7 +53,7 @@ func (el *ErrorLogEntry) PrettyPrint(writer io.Writer) { } func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c := newContext(gofrHTTP.NewResponder(w, r.Method), gofrHTTP.NewRequest(r), h.container) + c := newContext(gofrHTTP.NewResponder(w, r.Method, gofrHTTP.WithLogger(h.container.Logger)), gofrHTTP.NewRequest(r), h.container) traceID := trace.SpanFromContext(r.Context()).SpanContext().TraceID().String() @@ -108,7 +108,13 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp.SetCustomHeaders(w) } - // Handler function completed + // SSE streams are long-lived; bypass request timeout. + // SSEResponse() is instant struct creation, so done always fires before timeout + // for any practical timeout value. If timeout wins, result is nil and this is a no-op. + if _, ok := result.(response.SSE); ok { + c.Context = r.Context() + } + c.responder.Respond(result, err) } diff --git a/pkg/gofr/http/middleware/logger.go b/pkg/gofr/http/middleware/logger.go index 47d0f5b826..e49f9ae72f 100644 --- a/pkg/gofr/http/middleware/logger.go +++ b/pkg/gofr/http/middleware/logger.go @@ -47,6 +47,18 @@ func (w *StatusResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, fmt.Errorf("%w: cannot hijack connection", errHijackNotSupported) } +// Flush delegates to the underlying http.Flusher if supported. +func (w *StatusResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Unwrap returns the underlying ResponseWriter for http.ResponseController. +func (w *StatusResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + // RequestLog represents a log entry for HTTP requests. type RequestLog struct { TraceID string `json:"trace_id,omitempty"` diff --git a/pkg/gofr/http/middleware/logger_test.go b/pkg/gofr/http/middleware/logger_test.go index 716affa863..d064f8d071 100644 --- a/pkg/gofr/http/middleware/logger_test.go +++ b/pkg/gofr/http/middleware/logger_test.go @@ -320,3 +320,43 @@ type mockAddr struct{} func (*mockAddr) Network() string { return "tcp" } func (*mockAddr) String() string { return "127.0.0.1:8080" } + +func Test_StatusResponseWriter_Flush_Supported(t *testing.T) { + rr := httptest.NewRecorder() + srw := &StatusResponseWriter{ResponseWriter: rr} + + // httptest.ResponseRecorder implements http.Flusher. + assert.NotPanics(t, func() { + srw.Flush() + }) + + assert.True(t, rr.Flushed, "expected recorder to be flushed") +} + +func Test_StatusResponseWriter_Flush_NotSupported(t *testing.T) { + writer := &nonFlushableWriter{header: http.Header{}} + srw := &StatusResponseWriter{ResponseWriter: writer} + + // Should not panic even if the underlying writer doesn't support Flusher. + assert.NotPanics(t, func() { + srw.Flush() + }) +} + +func Test_StatusResponseWriter_Unwrap(t *testing.T) { + rr := httptest.NewRecorder() + srw := &StatusResponseWriter{ResponseWriter: rr} + + unwrapped := srw.Unwrap() + + assert.Equal(t, rr, unwrapped, "expected Unwrap to return the underlying ResponseWriter") +} + +// nonFlushableWriter is a ResponseWriter that does NOT implement http.Flusher. +type nonFlushableWriter struct { + header http.Header +} + +func (n *nonFlushableWriter) Header() http.Header { return n.header } +func (*nonFlushableWriter) Write([]byte) (int, error) { return 0, nil } +func (*nonFlushableWriter) WriteHeader(int) {} diff --git a/pkg/gofr/http/responder.go b/pkg/gofr/http/responder.go index 435f283394..2a87ef8492 100644 --- a/pkg/gofr/http/responder.go +++ b/pkg/gofr/http/responder.go @@ -13,15 +13,35 @@ var ( errEmptyResponse = errors.New("internal server error") ) +// sseLogger is a minimal logging interface used only for SSE error reporting. +type sseLogger interface { + Debugf(format string, args ...any) +} + +// ResponderOption configures optional Responder behavior. +type ResponderOption func(*Responder) + +// WithLogger attaches a logger to the Responder for debug-level SSE error logging. +func WithLogger(l sseLogger) ResponderOption { + return func(r *Responder) { r.logger = l } +} + // NewResponder creates a new Responder instance from the given http.ResponseWriter. -func NewResponder(w http.ResponseWriter, method string) *Responder { - return &Responder{w: w, method: method} +func NewResponder(w http.ResponseWriter, method string, opts ...ResponderOption) *Responder { + r := &Responder{w: w, method: method} + + for _, o := range opts { + o(r) + } + + return r } // Responder encapsulates an http.ResponseWriter and is responsible for crafting structured responses. type Responder struct { w http.ResponseWriter method string + logger sseLogger } // Respond sends a response with the given data and handles potential errors, setting appropriate @@ -75,6 +95,10 @@ func (r Responder) handleSpecialResponseTypes(data any, err error) bool { statusCode := r.getStatusCodeForSpecialResponse(data, err) switch v := data.(type) { + case resTypes.SSE: + r.handleSSEResponse(v) + return true + case resTypes.File: r.w.Header().Set("Content-Type", v.ContentType) r.w.WriteHeader(statusCode) @@ -276,3 +300,32 @@ func isNil(i any) bool { return v.Kind() == reflect.Ptr && v.IsNil() } + +// handleSSEResponse handles Server-Sent Events responses. +// +// TODO: SSE connections block for the full connection lifetime, causing the logging middleware +// and response histogram to record the entire duration. Consider labeling SSE in the histogram. +func (r Responder) handleSSEResponse(sse resTypes.SSE) { + if sse.Callback == nil { + if r.logger != nil { + r.logger.Debugf("SSE response has nil callback") + } + + return + } + + r.w.Header().Set("Content-Type", "text/event-stream") + r.w.Header().Set("Cache-Control", "no-cache") + r.w.Header().Set("Connection", "keep-alive") + r.w.Header().Set("X-Accel-Buffering", "no") + r.w.WriteHeader(http.StatusOK) + + rc := http.NewResponseController(r.w) + _ = rc.Flush() + + if err := sse.Callback(r.w, rc); err != nil { + if r.logger != nil { + r.logger.Debugf("SSE stream error: %v", err) + } + } +} diff --git a/pkg/gofr/http/responder_test.go b/pkg/gofr/http/responder_test.go index b81bc76ed8..cf5980a5eb 100644 --- a/pkg/gofr/http/responder_test.go +++ b/pkg/gofr/http/responder_test.go @@ -16,7 +16,10 @@ import ( resTypes "gofr.dev/pkg/gofr/http/response" ) -var errTest = fmt.Errorf("internal server error") +var ( + errTest = fmt.Errorf("internal server error") // Existing + errSSECallbackTest = fmt.Errorf("test error") +) func TestResponder(t *testing.T) { tests := []struct { @@ -604,3 +607,38 @@ func TestResponder_ValidEncodableData(t *testing.T) { assert.NotEmpty(t, body.String(), "TEST[%d] Failed: %s", i, tc.desc) } } + +func TestResponder_SSE_NilCallback(t *testing.T) { + w := httptest.NewRecorder() + r := NewResponder(w, http.MethodGet) + + // Zero-value SSE has nil Callback — should not panic. + r.Respond(resTypes.SSE{}, nil) + + // No SSE headers should be set since callback is nil. + assert.Empty(t, w.Header().Get("Content-Type")) +} + +// mockSSELogger captures Debugf calls for test assertions. +type mockSSELogger struct { + messages []string +} + +func (m *mockSSELogger) Debugf(format string, args ...any) { + m.messages = append(m.messages, fmt.Sprintf(format, args...)) +} + +func TestResponder_SSE_CallbackError(t *testing.T) { + w := httptest.NewRecorder() + logger := &mockSSELogger{} + r := NewResponder(w, http.MethodGet, WithLogger(logger)) + + r.Respond(resTypes.SSE{ + Callback: func(_ http.ResponseWriter, _ *http.ResponseController) error { + return errSSECallbackTest + }, + }, nil) + + require.Len(t, logger.messages, 1) + assert.Contains(t, logger.messages[0], errSSECallbackTest.Error()) +} diff --git a/pkg/gofr/http/response/sse.go b/pkg/gofr/http/response/sse.go new file mode 100644 index 0000000000..2c07ae72a2 --- /dev/null +++ b/pkg/gofr/http/response/sse.go @@ -0,0 +1,12 @@ +package response + +import "net/http" + +// SSECallback is the function signature for SSE streaming callbacks. +type SSECallback func(w http.ResponseWriter, rc *http.ResponseController) error + +// SSE represents a Server-Sent Events response. +// Return this from a handler to stream events to the client. +type SSE struct { + Callback SSECallback +} diff --git a/pkg/gofr/http_server.go b/pkg/gofr/http_server.go index 1597f48235..aa83f6263a 100644 --- a/pkg/gofr/http_server.go +++ b/pkg/gofr/http_server.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "sync" "time" "gofr.dev/pkg/gofr/container" @@ -19,6 +20,7 @@ type httpServer struct { port int ws *websocket.Manager srv *http.Server + mu sync.Mutex certFile string keyFile string staticFiles map[string]string @@ -58,18 +60,23 @@ func (s *httpServer) run(c *container.Container) { middleware.WSHandlerUpgrade(c, s.ws), ) + s.mu.Lock() if s.srv != nil { + s.mu.Unlock() c.Logf("Server already running on port: %d", s.port) + return } c.Logf("Starting server on port: %d", s.port) - s.srv = &http.Server{ + srv := &http.Server{ Addr: fmt.Sprintf(":%d", s.port), Handler: s.router, ReadHeaderTimeout: 5 * time.Second, } + s.srv = srv + s.mu.Unlock() // If both certFile and keyFile are provided, validate and run HTTPS server if s.certFile != "" && s.keyFile != "" { @@ -79,7 +86,7 @@ func (s *httpServer) run(c *container.Container) { } // Start HTTPS server with TLS - if err := s.srv.ListenAndServeTLS(s.certFile, s.keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := srv.ListenAndServeTLS(s.certFile, s.keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) { c.Errorf("error while listening to https server, err: %v", err) } @@ -87,20 +94,24 @@ func (s *httpServer) run(c *container.Container) { } // If no certFile/keyFile is provided, run the HTTP server - if err := s.srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { c.Errorf("error while listening to http server, err: %v", err) } } func (s *httpServer) Shutdown(ctx context.Context) error { - if s.srv == nil { + s.mu.Lock() + srv := s.srv + s.mu.Unlock() + + if srv == nil { return nil } return ShutdownWithContext(ctx, func(ctx context.Context) error { - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) }, func() error { - if err := s.srv.Close(); err != nil { + if err := srv.Close(); err != nil { return err } diff --git a/pkg/gofr/metrics_server.go b/pkg/gofr/metrics_server.go index 97a9b48770..0670b3c16b 100644 --- a/pkg/gofr/metrics_server.go +++ b/pkg/gofr/metrics_server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "sync" "time" "gofr.dev/pkg/gofr/container" @@ -14,6 +15,7 @@ import ( type metricServer struct { port int srv *http.Server + mu sync.Mutex } func newMetricServer(port int) *metricServer { @@ -24,13 +26,17 @@ func (m *metricServer) Run(c *container.Container) { if m != nil { c.Logf("Starting metrics server on port: %d", m.port) - m.srv = &http.Server{ + srv := &http.Server{ Addr: fmt.Sprintf(":%d", m.port), Handler: metrics.GetHandler(c.Metrics()), ReadHeaderTimeout: 5 * time.Second, } - err := m.srv.ListenAndServe() + m.mu.Lock() + m.srv = srv + m.mu.Unlock() + + err := srv.ListenAndServe() if !errors.Is(err, http.ErrServerClosed) { c.Errorf("error while listening to metrics server, err: %v", err) @@ -39,11 +45,15 @@ func (m *metricServer) Run(c *container.Container) { } func (m *metricServer) Shutdown(ctx context.Context) error { - if m.srv == nil { + m.mu.Lock() + srv := m.srv + m.mu.Unlock() + + if srv == nil { return nil } return ShutdownWithContext(ctx, func(ctx context.Context) error { - return m.srv.Shutdown(ctx) + return srv.Shutdown(ctx) }, nil) } diff --git a/pkg/gofr/sse.go b/pkg/gofr/sse.go new file mode 100644 index 0000000000..9807d5b285 --- /dev/null +++ b/pkg/gofr/sse.go @@ -0,0 +1,191 @@ +package gofr + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "gofr.dev/pkg/gofr/http/response" +) + +// defaultHeartbeatInterval is the interval between automatic heartbeat comments. +const defaultHeartbeatInterval = 15 * time.Second + +// SSEEvent represents a single Server-Sent Event. +type SSEEvent struct { + Name string // Event type (event: field) + Data any // Event data (data: field) - strings pass through, others are JSON-encoded + ID string // Event ID (id: field) + Retry int // Reconnection time in milliseconds (retry: field) +} + +// SSEStream writes Server-Sent Events directly to the HTTP response. +// It is safe for concurrent use; a mutex serializes all writes. +type SSEStream struct { + w http.ResponseWriter + rc *http.ResponseController + mu sync.Mutex +} + +// SSEFunc is the callback signature for SSE handlers. +// The function receives an SSEStream to write events. +type SSEFunc func(stream *SSEStream) error + +// SSEResponse creates an SSE response that can be returned from a handler. +// A heartbeat comment is automatically sent every 15s to keep the connection +// alive through proxies with idle timeouts. +// +// SSE handlers should return near-instantly (SSEResponse is just struct creation). +// The long-lived streaming callback runs later inside Respond(). +// +// Example: +// +// app.GET("/events", func(c *gofr.Context) (any, error) { +// return gofr.SSEResponse(func(stream *gofr.SSEStream) error { +// for i := 0; i < 10; i++ { +// if err := stream.SendEvent("counter", i); err != nil { +// return err +// } +// time.Sleep(time.Second) +// } +// return nil +// }), nil +// }) +func SSEResponse(callback SSEFunc) response.SSE { + return response.SSE{ + Callback: func(w http.ResponseWriter, rc *http.ResponseController) error { + stream := &SSEStream{w: w, rc: rc} + + done := make(chan struct{}) + go stream.runHeartbeat(done, defaultHeartbeatInterval) + + defer close(done) + + return callback(stream) + }, + } +} + +// Send writes a formatted SSE event to the stream and flushes. +func (s *SSEStream) Send(event SSEEvent) error { + return s.writeEvent(event) +} + +// SendData is shorthand for Send(SSEEvent{Data: data}). +func (s *SSEStream) SendData(data any) error { + return s.writeEvent(SSEEvent{Data: data}) +} + +// SendEvent is shorthand for Send(SSEEvent{Name: name, Data: data}). +func (s *SSEStream) SendEvent(name string, data any) error { + return s.writeEvent(SSEEvent{Name: name, Data: data}) +} + +// SendComment writes an SSE comment (: prefix) to the stream. +// Comments are often used as keep-alive heartbeats. +func (s *SSEStream) SendComment(text string) error { + var sb strings.Builder + + for _, line := range strings.Split(text, "\n") { + fmt.Fprintf(&sb, ": %s\n", line) + } + + sb.WriteString("\n") + + s.mu.Lock() + defer s.mu.Unlock() + + if _, err := fmt.Fprint(s.w, sb.String()); err != nil { + return err + } + + return s.rc.Flush() +} + +// writeEvent formats, writes, and flushes a single SSE event. +func (s *SSEStream) writeEvent(event SSEEvent) error { + raw, err := formatEvent(event) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, err := fmt.Fprint(s.w, raw); err != nil { + return err + } + + return s.rc.Flush() +} + +// runHeartbeat sends periodic comment frames to keep the connection alive +// through proxies that kill idle connections. Stops when done is closed. +func (s *SSEStream) runHeartbeat(done <-chan struct{}, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-done: + return + case <-ticker.C: + if err := s.SendComment("heartbeat"); err != nil { + return + } + } + } +} + +// formatEvent builds the wire-format string for one SSE event. +func formatEvent(event SSEEvent) (string, error) { + var sb strings.Builder + + if event.ID != "" { + fmt.Fprintf(&sb, "id: %s\n", event.ID) + } + + if event.Name != "" { + fmt.Fprintf(&sb, "event: %s\n", event.Name) + } + + if event.Retry > 0 { + fmt.Fprintf(&sb, "retry: %d\n", event.Retry) + } + + dataStr, err := formatSSEData(event.Data) + if err != nil { + return "", err + } + + for _, line := range strings.Split(dataStr, "\n") { + fmt.Fprintf(&sb, "data: %s\n", line) + } + + sb.WriteString("\n") + + return sb.String(), nil +} + +// formatSSEData converts data to a string for SSE. +// Strings and []byte pass through; everything else is JSON-encoded. +func formatSSEData(data any) (string, error) { + switch v := data.(type) { + case string: + return v, nil + case []byte: + return string(v), nil + case nil: + return "", nil + default: + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("failed to marshal SSE data: %w", err) + } + + return string(b), nil + } +} diff --git a/pkg/gofr/sse_test.go b/pkg/gofr/sse_test.go new file mode 100644 index 0000000000..40322441e5 --- /dev/null +++ b/pkg/gofr/sse_test.go @@ -0,0 +1,473 @@ +package gofr + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gofr.dev/pkg/gofr/testutil" +) + +var errSSEStreamFailed = fmt.Errorf("stream failed") + +func Test_formatSSEData(t *testing.T) { + tests := []struct { + name string + input any + expected string + }{ + {name: "string data", input: "hello world", expected: "hello world"}, + {name: "byte slice data", input: []byte("raw bytes"), expected: "raw bytes"}, + {name: "nil data", input: nil, expected: ""}, + {name: "struct data", input: struct { + Name string `json:"name"` + Age int `json:"age"` + }{Name: "GoFr", Age: 1}, expected: `{"name":"GoFr","age":1}`}, + {name: "map data", input: map[string]string{"key": "value"}, expected: `{"key":"value"}`}, + {name: "integer data", input: 42, expected: "42"}, + {name: "boolean data", input: true, expected: "true"}, + {name: "slice data", input: []int{1, 2, 3}, expected: "[1,2,3]"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := formatSSEData(tc.input) + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } +} + +func Test_formatSSEData_UnsupportedType(t *testing.T) { + ch := make(chan int) + _, err := formatSSEData(ch) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal SSE data") +} + +func Test_formatEvent(t *testing.T) { + tests := []struct { + name string + event SSEEvent + expected string + }{ + { + name: "data only", + event: SSEEvent{Data: "hello"}, + expected: "data: hello\n\n", + }, + { + name: "named event", + event: SSEEvent{Name: "update", Data: "new data"}, + expected: "event: update\ndata: new data\n\n", + }, + { + name: "event with ID", + event: SSEEvent{ID: "123", Data: "test"}, + expected: "id: 123\ndata: test\n\n", + }, + { + name: "event with retry", + event: SSEEvent{Retry: 5000, Data: "retry test"}, + expected: "retry: 5000\ndata: retry test\n\n", + }, + { + name: "all fields", + event: SSEEvent{ID: "42", Name: "message", Retry: 3000, Data: "full event"}, + expected: "id: 42\nevent: message\nretry: 3000\ndata: full event\n\n", + }, + { + name: "JSON data", + event: SSEEvent{Data: map[string]string{"key": "value"}}, + expected: "data: {\"key\":\"value\"}\n\n", + }, + { + name: "multiline data", + event: SSEEvent{Data: "line1\nline2\nline3"}, + expected: "data: line1\ndata: line2\ndata: line3\n\n", + }, + { + name: "nil data", + event: SSEEvent{Data: nil}, + expected: "data: \n\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := formatEvent(tc.event) + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } +} + +// flushRecorder is a ResponseRecorder that implements http.Flusher. +type flushRecorder struct { + *httptest.ResponseRecorder + flushed bool +} + +func newFlushRecorder() *flushRecorder { + return &flushRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (f *flushRecorder) Flush() { + f.flushed = true +} + +func (f *flushRecorder) Unwrap() http.ResponseWriter { + return f.ResponseRecorder +} + +func TestSSEStream_Send(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.Send(SSEEvent{Name: "test", Data: "hello"}) + require.NoError(t, err) + + assert.Equal(t, "event: test\ndata: hello\n\n", w.Body.String()) + assert.True(t, w.flushed) +} + +func TestSSEStream_SendData(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendData("simple") + require.NoError(t, err) + + assert.Equal(t, "data: simple\n\n", w.Body.String()) +} + +func TestSSEStream_SendEvent(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendEvent("notification", map[string]int{"count": 5}) + require.NoError(t, err) + + assert.Equal(t, "event: notification\ndata: {\"count\":5}\n\n", w.Body.String()) +} + +func TestSSEStream_SendComment(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendComment("keep-alive") + require.NoError(t, err) + + assert.Equal(t, ": keep-alive\n\n", w.Body.String()) +} + +func TestSSEStream_SendComment_Multiline(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendComment("line1\nline2") + require.NoError(t, err) + + assert.Equal(t, ": line1\n: line2\n\n", w.Body.String()) +} + +func TestSSEStream_Send_JSONStruct(t *testing.T) { + type Notification struct { + Title string `json:"title"` + Message string `json:"message"` + } + + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.Send(SSEEvent{ + Name: "notification", + ID: "1", + Data: Notification{Title: "Hello", Message: "World"}, + }) + + require.NoError(t, err) + assert.Equal(t, "id: 1\nevent: notification\ndata: {\"title\":\"Hello\",\"message\":\"World\"}\n\n", w.Body.String()) +} + +// safeFlushRecorder is a thread-safe ResponseWriter for concurrent tests. +type safeFlushRecorder struct { + mu sync.Mutex + buf []byte + headers http.Header +} + +func newSafeFlushRecorder() *safeFlushRecorder { + return &safeFlushRecorder{headers: http.Header{}} +} + +func (s *safeFlushRecorder) Header() http.Header { return s.headers } + +func (s *safeFlushRecorder) Write(b []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + s.buf = append(s.buf, b...) + + return len(b), nil +} + +func (*safeFlushRecorder) WriteHeader(int) {} + +func (*safeFlushRecorder) Flush() {} + +func (s *safeFlushRecorder) Unwrap() http.ResponseWriter { return s } + +func (s *safeFlushRecorder) String() string { + s.mu.Lock() + defer s.mu.Unlock() + + return string(s.buf) +} + +func TestSSEStream_ConcurrentSend(t *testing.T) { + w := newSafeFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + count := 50 + + var wg sync.WaitGroup + + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + + _ = stream.SendData(i) + }(i) + } + + wg.Wait() + + body := w.String() + + dataCount := strings.Count(body, "data:") + + assert.Equal(t, count, dataCount, "all events should be written with mutex protection") +} + +func TestSSEStream_StreamingLoop(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + for i := 0; i < 5; i++ { + err := stream.Send(SSEEvent{ + ID: string(rune('0' + i)), + Name: "tick", + Data: map[string]int{"count": i}, + }) + require.NoError(t, err) + } + + body := w.Body.String() + + for i := 0; i < 5; i++ { + expected, _ := json.Marshal(map[string]int{"count": i}) + assert.Contains(t, body, "data: "+string(expected)+"\n") + } +} + +// nonFlushableWriter is a ResponseWriter that does NOT implement http.Flusher. +type nonFlushableWriter struct { + header http.Header +} + +func (n *nonFlushableWriter) Header() http.Header { return n.header } +func (*nonFlushableWriter) Write([]byte) (int, error) { return 0, nil } +func (*nonFlushableWriter) WriteHeader(int) {} + +func TestSSEStream_NoFlusher(t *testing.T) { + w := &nonFlushableWriter{header: http.Header{}} + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendData("test") + assert.Error(t, err) +} + +func TestSSEStream_Heartbeat(t *testing.T) { + w := newSafeFlushRecorder() + rc := http.NewResponseController(w) + stream := &SSEStream{w: w, rc: rc} + + done := make(chan struct{}) + + go stream.runHeartbeat(done, 50*time.Millisecond) + + time.Sleep(150 * time.Millisecond) + close(done) + + time.Sleep(10 * time.Millisecond) + + body := w.String() + heartbeatCount := strings.Count(body, ": heartbeat") + + assert.GreaterOrEqual(t, heartbeatCount, 2, "should have at least 2 heartbeat comments") +} + +func TestSSEResponse_Integration(t *testing.T) { + configs := testutil.NewServerConfigs(t) + + app := New() + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = app.Shutdown(ctx) + }) + + app.GET("/events", func(_ *Context) (any, error) { + return SSEResponse(func(stream *SSEStream) error { + for i := 0; i < 3; i++ { + if err := stream.SendData(map[string]int{"count": i}); err != nil { + return err + } + } + + return nil + }), nil + }) + + go app.Run() + + // Wait for server readiness via alive endpoint. + for i := 0; i < 50; i++ { + req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, configs.HTTPHost+"/.well-known/alive", http.NoBody) + if reqErr != nil { + t.Fatalf("create request: %v", reqErr) + } + + probe, doErr := http.DefaultClient.Do(req) + if doErr == nil { + probe.Body.Close() + + break + } + + time.Sleep(100 * time.Millisecond) + } + + req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, configs.HTTPHost+"/events", http.NoBody) + require.NoError(t, reqErr) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + + assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, readErr := io.ReadAll(resp.Body) + require.NoError(t, readErr) + + bodyStr := string(body) + + assert.Contains(t, bodyStr, "data: {\"count\":0}\n") + assert.Contains(t, bodyStr, "data: {\"count\":1}\n") + assert.Contains(t, bodyStr, "data: {\"count\":2}\n") +} + +func TestSSEResponse_Integration_ClientDisconnect(t *testing.T) { + configs := testutil.NewServerConfigs(t) + + app := New() + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = app.Shutdown(ctx) + }) + + handlerExited := make(chan struct{}) + + app.GET("/stream", func(c *Context) (any, error) { + return SSEResponse(func(stream *SSEStream) error { + defer close(handlerExited) + + _ = stream.SendData("connected") + + <-c.Context.Done() + + return nil + }), nil + }) + + go app.Run() + + // Wait for server readiness via alive endpoint. + for i := 0; i < 50; i++ { + req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, configs.HTTPHost+"/.well-known/alive", http.NoBody) + if reqErr != nil { + t.Fatalf("create request: %v", reqErr) + } + + probe, doErr := http.DefaultClient.Do(req) + if doErr == nil { + probe.Body.Close() + + break + } + + time.Sleep(100 * time.Millisecond) + } + + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, configs.HTTPHost+"/stream", http.NoBody) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + buf := make([]byte, 512) + n, _ := resp.Body.Read(buf) + assert.Contains(t, string(buf[:n]), "data: connected") + + cancel() + resp.Body.Close() + + select { + case <-handlerExited: + case <-time.After(5 * time.Second): + t.Fatal("handler did not exit after client disconnect") + } +} + +func TestSSEResponse_NilCallback(t *testing.T) { + w := newFlushRecorder() + stream := &SSEStream{w: w, rc: http.NewResponseController(w)} + + err := stream.SendData("direct write") + require.NoError(t, err) + + assert.Contains(t, w.Body.String(), "data: direct write") +} + +func TestSSEResponse_CallbackError(t *testing.T) { + w := newFlushRecorder() + rc := http.NewResponseController(w) + + sse := SSEResponse(func(_ *SSEStream) error { + return errSSEStreamFailed + }) + + // Call the callback directly, simulating what the responder does. + err := sse.Callback(w, rc) + assert.ErrorIs(t, err, errSSEStreamFailed) +}