|
1 | 1 | package httpx
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
4 | 5 | "context"
|
5 | 6 | "errors"
|
6 | 7 | "fmt"
|
| 8 | + "io" |
7 | 9 | "net/http"
|
| 10 | + "net/http/httptest" |
8 | 11 | "strings"
|
9 | 12 | "testing"
|
10 | 13 |
|
@@ -239,6 +242,60 @@ func TestWriteJsonMarshalFailed(t *testing.T) {
|
239 | 242 | assert.Equal(t, http.StatusInternalServerError, w.code)
|
240 | 243 | }
|
241 | 244 |
|
| 245 | +func TestStream(t *testing.T) { |
| 246 | + t.Run("regular case", func(t *testing.T) { |
| 247 | + channel := make(chan string) |
| 248 | + go func() { |
| 249 | + defer close(channel) |
| 250 | + for index := 0; index < 5; index++ { |
| 251 | + channel <- fmt.Sprintf("%d", index) |
| 252 | + } |
| 253 | + }() |
| 254 | + |
| 255 | + w := httptest.NewRecorder() |
| 256 | + Stream(context.Background(), w, func(w io.Writer) bool { |
| 257 | + output, ok := <-channel |
| 258 | + if !ok { |
| 259 | + return false |
| 260 | + } |
| 261 | + |
| 262 | + outputBytes := bytes.NewBufferString(output) |
| 263 | + _, err := w.Write(append(outputBytes.Bytes(), []byte("\n")...)) |
| 264 | + return err == nil |
| 265 | + }) |
| 266 | + |
| 267 | + assert.Equal(t, http.StatusOK, w.Code) |
| 268 | + assert.Equal(t, "0\n1\n2\n3\n4\n", w.Body.String()) |
| 269 | + }) |
| 270 | + |
| 271 | + t.Run("context done", func(t *testing.T) { |
| 272 | + channel := make(chan string) |
| 273 | + go func() { |
| 274 | + defer close(channel) |
| 275 | + for index := 0; index < 5; index++ { |
| 276 | + channel <- fmt.Sprintf("num: %d", index) |
| 277 | + } |
| 278 | + }() |
| 279 | + |
| 280 | + w := httptest.NewRecorder() |
| 281 | + ctx, cancel := context.WithCancel(context.Background()) |
| 282 | + cancel() |
| 283 | + Stream(ctx, w, func(w io.Writer) bool { |
| 284 | + output, ok := <-channel |
| 285 | + if !ok { |
| 286 | + return false |
| 287 | + } |
| 288 | + |
| 289 | + outputBytes := bytes.NewBufferString(output) |
| 290 | + _, err := w.Write(append(outputBytes.Bytes(), []byte("\n")...)) |
| 291 | + return err == nil |
| 292 | + }) |
| 293 | + |
| 294 | + assert.Equal(t, http.StatusOK, w.Code) |
| 295 | + assert.Equal(t, "", w.Body.String()) |
| 296 | + }) |
| 297 | +} |
| 298 | + |
242 | 299 | type tracedResponseWriter struct {
|
243 | 300 | headers map[string][]string
|
244 | 301 | builder strings.Builder
|
|
0 commit comments