Skip to content

Commit c639081

Browse files
vegetablechicken233xiyuliujeevatkm
authored
feat: sse client support method and body configurable (#988)
Co-authored-by: xiyuliu <[email protected]> Co-authored-by: Jeevanandam M. <[email protected]>
1 parent 981bb3f commit c639081

File tree

2 files changed

+181
-1
lines changed

2 files changed

+181
-1
lines changed

sse.go

+27-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
var (
2525
defaultSseMaxBufSize = 1 << 15 // 32kb
2626
defaultEventName = "message"
27+
defaultHTTPMethod = MethodGet
2728

2829
headerID = []byte("id:")
2930
headerData = []byte("data:")
@@ -63,7 +64,9 @@ type (
6364
EventSource struct {
6465
lock *sync.RWMutex
6566
url string
67+
method string
6668
header http.Header
69+
body io.Reader
6770
lastEventID string
6871
retryCount int
6972
retryWaitTime time.Duration
@@ -126,6 +129,14 @@ func (es *EventSource) SetURL(url string) *EventSource {
126129
return es
127130
}
128131

132+
// SetMethod method sets a [EventSource] connection HTTP method in the instance
133+
//
134+
// es.SetMethod("POST"), or es.SetMethod(resty.MethodPost)
135+
func (es *EventSource) SetMethod(method string) *EventSource {
136+
es.method = method
137+
return es
138+
}
139+
129140
// SetHeader method sets a header and its value to the [EventSource] instance.
130141
// It overwrites the header value if the key already exists. These headers will be sent in
131142
// the request while establishing a connection to the event source
@@ -139,6 +150,15 @@ func (es *EventSource) SetHeader(header, value string) *EventSource {
139150
return es
140151
}
141152

153+
// SetBody method sets body value to the [EventSource] instance
154+
//
155+
// Example:
156+
// es.SetBody(bytes.NewReader([]byte(`{"test":"put_data"}`)))
157+
func (es *EventSource) SetBody(body io.Reader) *EventSource {
158+
es.body = body
159+
return es
160+
}
161+
142162
// AddHeader method adds a header and its value to the [EventSource] instance.
143163
// If the header key already exists, it appends. These headers will be sent in
144164
// the request while establishing a connection to the event source
@@ -344,6 +364,12 @@ func (es *EventSource) Get() error {
344364
return fmt.Errorf("resty:sse: event source URL is required")
345365
}
346366

367+
if isStringEmpty(es.method) {
368+
// It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here.
369+
// Ensure compatibility, use GET as default http method
370+
es.method = defaultHTTPMethod
371+
}
372+
347373
if len(es.onEvent) == 0 {
348374
return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required")
349375
}
@@ -402,7 +428,7 @@ func (es *EventSource) triggerOnError(err error) {
402428
}
403429

404430
func (es *EventSource) createRequest() (*http.Request, error) {
405-
req, err := http.NewRequest(MethodGet, es.url, nil)
431+
req, err := http.NewRequest(es.method, es.url, es.body)
406432
if err != nil {
407433
return nil, err
408434
}

sse_test.go

+154
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func TestEventSourceSimpleFlow(t *testing.T) {
4646
defer ts.Close()
4747

4848
es.SetURL(ts.URL)
49+
es.SetMethod(MethodPost)
4950
err := es.Get()
5051
assertNil(t, err)
5152
assertEqual(t, counter, messageCounter)
@@ -115,6 +116,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
115116
defer ts.Close()
116117

117118
es.SetURL(ts.URL).
119+
SetMethod(MethodPost).
118120
AddEventListener("user_connect", userConnectFunc, userEvent{}).
119121
AddEventListener("user_message", userMessageFunc, userEvent{})
120122

@@ -354,6 +356,7 @@ func TestEventSourceCoverage(t *testing.T) {
354356
func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource {
355357
es := NewEventSource().
356358
SetURL(url).
359+
SetMethod(MethodGet).
357360
AddHeader("X-Test-Header-1", "test header 1").
358361
SetHeader("X-Test-Header-2", "test header 2").
359362
SetRetryCount(2).
@@ -406,3 +409,154 @@ func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer)
406409
}
407410
})
408411
}
412+
413+
func TestEventSourceWithDifferentMethods(t *testing.T) {
414+
testCases := []struct {
415+
name string
416+
method string
417+
body []byte
418+
}{
419+
{
420+
name: "GET Method",
421+
method: MethodGet,
422+
body: nil,
423+
},
424+
{
425+
name: "POST Method",
426+
method: MethodPost,
427+
body: []byte(`{"test":"post_data"}`),
428+
},
429+
{
430+
name: "PUT Method",
431+
method: MethodPut,
432+
body: []byte(`{"test":"put_data"}`),
433+
},
434+
{
435+
name: "DELETE Method",
436+
method: MethodDelete,
437+
body: nil,
438+
},
439+
{
440+
name: "PATCH Method",
441+
method: MethodPatch,
442+
body: []byte(`{"test":"patch_data"}`),
443+
},
444+
}
445+
446+
for _, tc := range testCases {
447+
t.Run(tc.name, func(t *testing.T) {
448+
messageCounter := 0
449+
messageFunc := func(e any) {
450+
event := e.(*Event)
451+
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
452+
assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
453+
messageCounter++
454+
}
455+
456+
counter := 0
457+
methodVerified := false
458+
bodyVerified := false
459+
460+
es := createEventSource(t, "", messageFunc, nil)
461+
ts := createMethodVerifyingSSETestServer(
462+
t,
463+
10*time.Millisecond,
464+
tc.method,
465+
tc.body,
466+
&methodVerified,
467+
&bodyVerified,
468+
func(w io.Writer) error {
469+
if counter == 20 {
470+
es.Close()
471+
return fmt.Errorf("stop sending events")
472+
}
473+
_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
474+
counter++
475+
return err
476+
},
477+
)
478+
defer ts.Close()
479+
480+
es.SetURL(ts.URL)
481+
es.SetMethod(tc.method)
482+
483+
// set body
484+
if tc.body != nil {
485+
es.SetBody(bytes.NewBuffer(tc.body))
486+
}
487+
488+
err := es.Get()
489+
assertNil(t, err)
490+
491+
// check the message count
492+
assertEqual(t, counter, messageCounter)
493+
494+
// check if server receive correct method and body
495+
assertEqual(t, true, methodVerified)
496+
if tc.body != nil {
497+
assertEqual(t, true, bodyVerified)
498+
}
499+
})
500+
}
501+
}
502+
503+
// almost like create server before but add verifying method and body
504+
func createMethodVerifyingSSETestServer(
505+
t *testing.T,
506+
ticker time.Duration,
507+
expectedMethod string,
508+
expectedBody []byte,
509+
methodVerified *bool,
510+
bodyVerified *bool,
511+
fn func(io.Writer) error,
512+
) *httptest.Server {
513+
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
514+
// validate method
515+
if r.Method == expectedMethod {
516+
*methodVerified = true
517+
} else {
518+
t.Errorf("Expected method %s, got %s", expectedMethod, r.Method)
519+
}
520+
521+
// validate body
522+
if expectedBody != nil {
523+
body, err := io.ReadAll(r.Body)
524+
if err != nil {
525+
t.Errorf("Failed to read request body: %v", err)
526+
} else if string(body) == string(expectedBody) {
527+
*bodyVerified = true
528+
} else {
529+
t.Errorf("Expected body %s, got %s", string(expectedBody), string(body))
530+
}
531+
}
532+
533+
// same as createSSETestServer
534+
w.Header().Set("Content-Type", "text/event-stream")
535+
w.Header().Set("Cache-Control", "no-cache")
536+
w.Header().Set("Connection", "keep-alive")
537+
w.Header().Set("Access-Control-Allow-Origin", "*")
538+
539+
clientGone := r.Context().Done()
540+
541+
rc := http.NewResponseController(w)
542+
tick := time.NewTicker(ticker)
543+
defer tick.Stop()
544+
545+
for {
546+
select {
547+
case <-clientGone:
548+
t.Log("Client disconnected")
549+
return
550+
case <-tick.C:
551+
if err := fn(w); err != nil {
552+
t.Log(err)
553+
return
554+
}
555+
if err := rc.Flush(); err != nil {
556+
t.Log(err)
557+
return
558+
}
559+
}
560+
}
561+
}))
562+
}

0 commit comments

Comments
 (0)