Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions internal/proxy/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,14 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}

// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
// Copy all headers from backend to client before WriteHeader,
// because headers set after WriteHeader are silently ignored.
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)

logger.Debug(ctx, "HTTP start streaming")

Expand Down Expand Up @@ -476,13 +477,14 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}

// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
// Copy all headers from backend to client before WriteHeader,
// because headers set after WriteHeader are silently ignored.
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)

// For TS file, directly copy it.
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
Expand All @@ -502,7 +504,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite

m3u8 := string(b)
if strings.Contains(m3u8, ".ts?") {
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&", v.SRSProxyBackendHLSID))
} else {
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
}
Expand Down
87 changes: 86 additions & 1 deletion internal/proxy/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) {
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) {
if want := "live-0.ts?spbhid=ABC&token=foo"; !strings.Contains(rec.Body.String(), want) {
t.Fatalf("missing %q in body: %q", want, rec.Body.String())
}
}
Expand All @@ -396,6 +396,61 @@ func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) {
}
}

func TestHLSPlayStream_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.apple.mpegurl")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("X-Custom-Header", "custom-value")
_, _ = io.WriteString(w, "#EXTM3U\nlive-0.ts\n")
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)

v := newHLSPlayStream()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}

// Verify headers are properly copied (not lost due to WriteHeader order)
if got := rec.Header().Get("Content-Type"); got != "application/vnd.apple.mpegurl" {
t.Errorf("Content-Type = %q, want application/vnd.apple.mpegurl", got)
}
if got := rec.Header().Get("Cache-Control"); got != "no-cache" {
t.Errorf("Cache-Control = %q, want no-cache", got)
}
if got := rec.Header().Get("X-Custom-Header"); got != "custom-value" {
t.Errorf("X-Custom-Header = %q, want custom-value", got)
}
}

func TestHLSPlayStream_ServeByBackend_TSHeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/mp2t")
w.Header().Set("Cache-Control", "max-age=3600")
_, _ = w.Write([]byte{0x47, 0x00, 0x01})
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)

v := newHLSPlayStream()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}

if got := rec.Header().Get("Content-Type"); got != "video/mp2t" {
t.Errorf("Content-Type = %q, want video/mp2t", got)
}
if got := rec.Header().Get("Cache-Control"); got != "max-age=3600" {
t.Errorf("Cache-Control = %q, want max-age=3600", got)
}
}

// =============================================================================
// httpFlvTsConnection
// =============================================================================
Expand Down Expand Up @@ -666,6 +721,36 @@ func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) {
}
}

func TestHTTPFlvTsConn_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/x-flv")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("X-Custom-Header", "flv-value")
_, _ = w.Write([]byte("FLV\x01\x05\x00\x00\x00\x09"))
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)

v := newHTTPFlvTsConnection()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}

// Verify headers are properly copied (not lost due to WriteHeader order)
if got := rec.Header().Get("Content-Type"); got != "video/x-flv" {
t.Errorf("Content-Type = %q, want video/x-flv", got)
}
if got := rec.Header().Get("Cache-Control"); got != "no-store" {
t.Errorf("Cache-Control = %q, want no-store", got)
}
if got := rec.Header().Get("X-Custom-Header"); got != "flv-value" {
t.Errorf("X-Custom-Header = %q, want flv-value", got)
}
}

// =============================================================================
// httpStreamProxyServer
// =============================================================================
Expand Down
Loading