Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/using-sse/configs/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
APP_NAME=using-sse
HTTP_PORT=9000
61 changes: 61 additions & 0 deletions examples/using-sse/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package main

import (
"fmt"
"time"

"gofr.dev/pkg/gofr"
)

func main() {
app := gofr.New()

// Stream the current time every second.
app.GET("/events", func(c *gofr.Context) (any, error) {
return gofr.SSEResponse(func(stream *gofr.SSEStream) error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

i := 0

for {
select {
case <-c.Context.Done():
return nil
case t := <-ticker.C:
if err := stream.Send(gofr.SSEEvent{
ID: fmt.Sprintf("%d", i),
Name: "time",
Data: map[string]string{"time": t.Format(time.RFC3339)},
}); err != nil {
return err
}

i++
}
}
}), nil
})

// A countdown that sends 11 events and closes.
app.GET("/countdown", func(c *gofr.Context) (any, error) {
return gofr.SSEResponse(func(stream *gofr.SSEStream) error {
for i := 10; i >= 0; i-- {
select {
case <-c.Context.Done():
return nil
default:
if err := stream.SendEvent("countdown", map[string]int{"remaining": i}); err != nil {
return err
}

time.Sleep(500 * time.Millisecond)
}
}

return stream.SendEvent("done", "Countdown complete!")
}), nil
})

app.Run()
}
12 changes: 12 additions & 0 deletions pkg/gofr/http/middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ func (w *StatusResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, fmt.Errorf("%w: cannot hijack connection", errHijackNotSupported)
}

// Flush delegates to the underlying http.Flusher if supported.
func (w *StatusResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

// Unwrap returns the underlying ResponseWriter for http.ResponseController.
func (w *StatusResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

// RequestLog represents a log entry for HTTP requests.
type RequestLog struct {
TraceID string `json:"trace_id,omitempty"`
Expand Down
40 changes: 40 additions & 0 deletions pkg/gofr/http/middleware/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,43 @@ type mockAddr struct{}

func (*mockAddr) Network() string { return "tcp" }
func (*mockAddr) String() string { return "127.0.0.1:8080" }

func Test_StatusResponseWriter_Flush_Supported(t *testing.T) {
rr := httptest.NewRecorder()
srw := &StatusResponseWriter{ResponseWriter: rr}

// httptest.ResponseRecorder implements http.Flusher.
assert.NotPanics(t, func() {
srw.Flush()
})

assert.True(t, rr.Flushed, "expected recorder to be flushed")
}

func Test_StatusResponseWriter_Flush_NotSupported(t *testing.T) {
writer := &nonFlushableWriter{header: http.Header{}}
srw := &StatusResponseWriter{ResponseWriter: writer}

// Should not panic even if the underlying writer doesn't support Flusher.
assert.NotPanics(t, func() {
srw.Flush()
})
}

func Test_StatusResponseWriter_Unwrap(t *testing.T) {
rr := httptest.NewRecorder()
srw := &StatusResponseWriter{ResponseWriter: rr}

unwrapped := srw.Unwrap()

assert.Equal(t, rr, unwrapped, "expected Unwrap to return the underlying ResponseWriter")
}

// nonFlushableWriter is a ResponseWriter that does NOT implement http.Flusher.
type nonFlushableWriter struct {
header http.Header
}

func (n *nonFlushableWriter) Header() http.Header { return n.header }
func (*nonFlushableWriter) Write([]byte) (int, error) { return 0, nil }
func (*nonFlushableWriter) WriteHeader(int) {}
22 changes: 22 additions & 0 deletions pkg/gofr/http/responder.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func (r Responder) handleSpecialResponseTypes(data any, err error) bool {
statusCode := r.getStatusCodeForSpecialResponse(data, err)

switch v := data.(type) {
case resTypes.SSE:
r.handleSSEResponse(v)
return true

case resTypes.File:
r.w.Header().Set("Content-Type", v.ContentType)
r.w.WriteHeader(statusCode)
Expand Down Expand Up @@ -276,3 +280,21 @@ func isNil(i any) bool {

return v.Kind() == reflect.Ptr && v.IsNil()
}

// handleSSEResponse handles Server-Sent Events responses.
// It sets appropriate headers, creates the stream, and calls the user's stream function.
func (r Responder) handleSSEResponse(sse resTypes.SSE) {
// Set SSE headers
r.w.Header().Set("Content-Type", "text/event-stream")
r.w.Header().Set("Cache-Control", "no-cache")
r.w.Header().Set("Connection", "keep-alive")
r.w.Header().Set("X-Accel-Buffering", "no")
r.w.WriteHeader(http.StatusOK)

// Initial flush to establish connection
rc := http.NewResponseController(r.w)
_ = rc.Flush()

// Intercept and stream events using the handler's callback
_ = sse.Stream(r.w)
}
12 changes: 12 additions & 0 deletions pkg/gofr/http/response/sse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package response

import (
"net/http"
)

// SSE represents a Server-Sent Events response.
// Return this from a handler to stream events to the client.
type SSE struct {
// Stream uses the provided ResponseWriter to send Server-Sent Events.
Stream func(w http.ResponseWriter) error
}
155 changes: 155 additions & 0 deletions pkg/gofr/sse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package gofr

import (
"encoding/json"
"fmt"
"net/http"
"strings"

"gofr.dev/pkg/gofr/http/response"
)

// SSEEvent represents a single Server-Sent Event.
type SSEEvent struct {
Name string // Event type (event: field)
Data any // Event data (data: field) - strings pass through, others are JSON-encoded
ID string // Event ID (id: field)
Retry int // Reconnection time in milliseconds (retry: field)
}

// SSEStream writes Server-Sent Events directly to the HTTP response.
// It wraps an http.ResponseWriter and flushes after each event.
type SSEStream struct {
w http.ResponseWriter
rc *http.ResponseController
}

// SSEFunc is the callback signature for SSE handlers.
// The function receives an SSEStream to write events.
type SSEFunc func(stream *SSEStream) error

// SSEResponse creates an SSE response that can be returned from a handler.
// The callback function is called by the Responder to produce SSE events.
//
// Example:
//
// app.GET("/events", func(c *gofr.Context) (any, error) {
// return gofr.SSEResponse(func(stream *gofr.SSEStream) error {
// for i := 0; i < 10; i++ {
// if err := stream.SendEvent("counter", i); err != nil {
// return err
// }
// time.Sleep(time.Second)
// }
// return nil
// }), nil
// })
func SSEResponse(callback SSEFunc) response.SSE {
return response.SSE{
Stream: func(w http.ResponseWriter) error {
stream := &SSEStream{
w: w,
rc: http.NewResponseController(w),
}
return callback(stream)
},
}
}

// Send writes a formatted SSE event to the stream and flushes.
func (s *SSEStream) Send(event any) error {
sseEvent, ok := event.(SSEEvent)
if !ok {
// If not an SSEEvent, treat as data-only event
return s.SendData(event)
}

raw, err := formatEvent(sseEvent)
if err != nil {
return err
}

if _, err := fmt.Fprint(s.w, raw); err != nil {
return err
}

return s.rc.Flush()
}

// SendData is shorthand for Send(SSEEvent{Data: data}).
func (s *SSEStream) SendData(data any) error {
return s.Send(SSEEvent{Data: data})
}

// SendEvent is shorthand for Send(SSEEvent{Name: name, Data: data}).
func (s *SSEStream) SendEvent(name string, data any) error {
return s.Send(SSEEvent{Name: name, Data: data})
}

// SendComment writes an SSE comment (: prefix) to the stream.
// Comments are often used as keep-alive heartbeats.
func (s *SSEStream) SendComment(text string) error {
var sb strings.Builder

for _, line := range strings.Split(text, "\n") {
fmt.Fprintf(&sb, ": %s\n", line)
}

sb.WriteString("\n")

if _, err := fmt.Fprint(s.w, sb.String()); err != nil {
return err
}

return s.rc.Flush()
}

// formatEvent builds the wire-format string for one SSE event.
func formatEvent(event SSEEvent) (string, error) {
var sb strings.Builder

if event.ID != "" {
fmt.Fprintf(&sb, "id: %s\n", event.ID)
}

if event.Name != "" {
fmt.Fprintf(&sb, "event: %s\n", event.Name)
}

if event.Retry > 0 {
fmt.Fprintf(&sb, "retry: %d\n", event.Retry)
}

dataStr, err := formatSSEData(event.Data)
if err != nil {
return "", err
}

for _, line := range strings.Split(dataStr, "\n") {
fmt.Fprintf(&sb, "data: %s\n", line)
}

sb.WriteString("\n")

return sb.String(), nil
}

// formatSSEData converts data to a string for SSE.
// Strings and []byte pass through; everything else is JSON-encoded.
func formatSSEData(data any) (string, error) {
switch v := data.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
case nil:
return "", nil
default:
b, err := json.Marshal(v)
if err != nil {
return "", fmt.Errorf("failed to marshal SSE data: %w", err)
}

return string(b), nil
}
}
Loading
Loading