-
Notifications
You must be signed in to change notification settings - Fork 186
/
Copy pathtrickle_publisher.go
303 lines (263 loc) · 7.19 KB
/
trickle_publisher.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
package trickle
import (
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
)
var (
	// StreamNotFoundErr is the result reported for a segment POST that the
	// server answers with 404 Not Found.
	StreamNotFoundErr = errors.New("stream not found")
)
// TricklePublisher is a trickle streaming client. It uploads a stream as a
// sequence of numbered segments, each sent as its own HTTP POST to
// <baseURL>/<index>.
type TricklePublisher struct {
	// client performs the segment POSTs; reconnect may replace it.
	client *http.Client

	// baseURL is the stream URL; segment URLs append "/<index>".
	baseURL string

	// index is the sequence number the next preconnected segment will use.
	index int

	// writeLock serializes access to index, client and pendingPost.
	writeLock sync.Mutex

	// pendingPost is a POST that has already been started and is waiting
	// for its segment data.
	pendingPost *pendingPost

	// contentType is sent as the Content-Type header of every segment POST.
	contentType string
}
// HTTPError reports an error status code returned by the server.
//
// Segment POSTs produce it for status codes >= 400 other than 404, which is
// reported as StreamNotFoundErr instead. Segment DELETEs (pendingPost.Close)
// produce it for any status other than 200 OK.
type HTTPError struct {
	// Code is the HTTP status code of the response.
	Code int

	// Body is the response body text, kept for diagnostics.
	Body string
}

// Error implements the error interface.
func (e *HTTPError) Error() string {
	return fmt.Sprintf("Status code %d - %s", e.Code, e.Body)
}
// pendingPost is a segment POST that has already been started in the
// background (see preconnect) and is waiting for its body to be written.
type pendingPost struct {
	// index is the segment sequence number this POST targets.
	index int

	// writer feeds the request body; closing it ends the upload.
	writer *io.PipeWriter

	// errCh receives the outcome of the POST (nil on success).
	errCh chan error

	// The remaining fields support reconnecting when a segment is written
	// more than once.

	// written records whether Write has already been called on this post.
	written bool

	// client is the publisher that created this post.
	client *TricklePublisher
}
// NewTricklePublisher returns a trickle publisher for the stream at url.
//
// It immediately starts the POST for segment 0, so the first Write can begin
// streaming without opening a new request. The only error comes from failing
// to construct that request.
func NewTricklePublisher(url string) (*TricklePublisher, error) {
	pub := &TricklePublisher{
		client:      httpClient(),
		baseURL:     url,
		contentType: "video/MP2T",
	}

	first, err := pub.preconnect()
	if err != nil {
		return nil, err
	}
	pub.pendingPost = first
	return pub, nil
}
// preconnect starts, in a background goroutine, a streaming POST to
// <baseURL>/<index> for the publisher's current index, then advances the index.
//
// The returned pendingPost holds the write end of the request body. Its errCh
// receives the request's outcome exactly once:
//   - nil on success,
//   - StreamNotFoundErr for 404,
//   - *HTTPError for other status codes >= 400,
//   - otherwise the transport or body-read error.
//
// NB: the caller must hold c.writeLock (or otherwise have exclusive access to
// c), because this reads c.client and mutates c.index.
func (c *TricklePublisher) preconnect() (*pendingPost, error) {
	index := c.index
	url := fmt.Sprintf("%s/%d", c.baseURL, index)
	slog.Debug("Preconnecting", "url", url)

	errCh := make(chan error, 1)
	pr, pw := io.Pipe()
	req, err := http.NewRequest("POST", url, pr)
	if err != nil {
		slog.Error("Failed to create request for segment", "url", url, "err", err)
		return nil, err
	}
	req.Header.Set("Content-Type", c.contentType)

	// Read c.client now, while the caller holds the lock. reconnect() may
	// replace c.client under that lock, and the goroutine below runs without
	// it, so reading the field from inside the goroutine would be a data race.
	client := c.client

	// Start the POST request in a background goroutine. errCh is buffered and
	// receives exactly one value, so the goroutine never blocks when sending.
	go func() {
		resp, err := client.Do(req)
		if err != nil {
			slog.Error("Failed to complete POST for segment", "url", url, "err", err)
			errCh <- err
			return
		}
		// Register the close before reading, so the body is released on every
		// path, including when ReadAll fails.
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			slog.Error("Error reading body", "url", url, "err", err)
			errCh <- err
			return
		}
		if resp.StatusCode != http.StatusOK {
			slog.Error("Failed POST segment", "url", url, "status_code", resp.StatusCode, "msg", string(body))
			if resp.StatusCode == http.StatusNotFound {
				errCh <- StreamNotFoundErr
				return
			}
			if resp.StatusCode >= 400 {
				errCh <- &HTTPError{Code: resp.StatusCode, Body: string(body)}
				return
			}
			// Other non-200 statuses are logged but not reported as failures.
		} else {
			slog.Debug("Uploaded segment", "url", url)
		}
		errCh <- nil
	}()

	c.index += 1
	return &pendingPost{
		writer: pw,
		index:  index,
		errCh:  errCh,
		client: c,
	}, nil
}
// Close asks the server to tear down the stream by sending a DELETE to the
// stream's base URL. Any status other than 200 OK is returned as an error that
// includes the response status and body.
func (c *TricklePublisher) Close() error {
	req, err := http.NewRequest("DELETE", c.baseURL, nil)
	if err != nil {
		return err
	}

	// Use a new client for a fresh connection. The client is discarded after
	// this call, so release its pooled connection once we are done. The
	// transport built by httpClient sets no IdleConnTimeout, so the
	// connection would otherwise stay open indefinitely. (Deferred before the
	// body close, so it runs after it.)
	client := httpClient()
	defer client.CloseIdleConnections()

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("Failed to delete stream: %v - %s", resp.Status, string(body))
	}
	return nil
}
// Next hands out the pendingPost for the current segment. It also
// preconnects the following segment, so that segment is ready for the next
// call. Access to c.pendingPost and c.index is serialized by c.writeLock.
func (c *TricklePublisher) Next() (*pendingPost, error) {
	c.writeLock.Lock()
	// Unlock only through this defer. An additional explicit Unlock on an
	// early-return path would unlock an already-unlocked mutex, which is a
	// fatal run-time error.
	defer c.writeLock.Unlock()

	// Get the post to hand out, creating one if none is pending.
	pp := c.pendingPost
	if pp == nil {
		p, err := c.preconnect()
		if err != nil {
			return nil, err
		}
		pp = p
		// Keep the freshly started post reachable. If preconnecting the next
		// segment fails below, it is handed out on a later call instead of
		// being abandoned with its request goroutine still waiting.
		c.pendingPost = pp
	}

	// Set up the next connection.
	nextPost, err := c.preconnect()
	if err != nil {
		return nil, err
	}
	c.pendingPost = nextPost
	return pp, nil
}
// reconnect opens a new POST for the same segment index as p, using a
// brand-new HTTP client to force a fresh connection. It is used when a
// pendingPost is written more than once.
//
// preconnect always uses (and advances) the publisher's own index. So this
// temporarily points the publisher at p.index and restores the original
// index afterwards, all while holding the publisher's writeLock.
func (p *pendingPost) reconnect() (*pendingPost, error) {
	pub := p.client
	pub.writeLock.Lock()
	defer pub.writeLock.Unlock()

	// Swap in a new client. Then release the old client's idle pooled
	// connections, which would otherwise stay open indefinitely (httpClient's
	// transport sets no IdleConnTimeout). CloseIdleConnections does not
	// interrupt requests that are still in flight on the old client.
	old := pub.client
	pub.client = httpClient()
	if old != nil {
		old.CloseIdleConnections()
	}

	savedIndex := pub.index
	pub.index = p.index
	pp, err := pub.preconnect()
	pub.index = savedIndex
	return pp, err
}
// Write streams data as the body of this segment's POST. It blocks until the
// request has finished and returns the number of bytes copied from data.
//
// Calling Write again on an already-written pendingPost re-POSTs the same
// segment index over a fresh connection (via reconnect).
//
// Errors reported by the request goroutine (transport failures,
// StreamNotFoundErr, *HTTPError) take precedence over local I/O errors. Those
// I/O errors are often just a "closed pipe" side effect of the request
// failing.
func (p *pendingPost) Write(data io.Reader) (int64, error) {
	// If writing multiple times, reconnect.
	if p.written {
		pp, err := p.reconnect()
		if err != nil {
			return 0, err
		}
		p = pp
	}

	var (
		writer = p.writer
		index  = p.index
		errCh  = p.errCh
	)
	// Mark as written.
	p.written = true

	// Before writing, check whether the preconnected request has already
	// finished (e.g. the server rejected it). If so, report its result.
	select {
	case err := <-errCh:
		return 0, err
	default:
		// No result yet; continue.
	}

	// Stream data into the current POST request's body.
	n, ioError := io.Copy(writer, data)

	var closeErr error
	if ioError == nil {
		slog.Debug("Completed writing", "idx", index, "totalBytes", humanBytes(n))
		// Close the pipe writer to signal end of data for the current POST request.
		closeErr = writer.Close()
	} else {
		// Abort the request body with the copy error. If the pipe were left
		// open (e.g. because data itself failed to read), the POST would keep
		// waiting for more body bytes, and the receive on errCh below could
		// block indefinitely. CloseWithError always returns nil.
		writer.CloseWithError(ioError)
	}

	// Check for errors after writing, e.g. >=400 status codes. These typically
	// do not surface as I/O errors from io.Copy. They are also preferred over
	// I/O errors such as "read/write on closed pipe".
	if err := <-errCh; err != nil {
		return n, err
	}
	if ioError != nil {
		return n, fmt.Errorf("error streaming data to segment %d: %w", index, ioError)
	}
	if closeErr != nil {
		return n, fmt.Errorf("error closing writer for segment %d: %w", index, closeErr)
	}
	return n, nil
}
// Close abandons this segment. It closes the request-body pipe and sends a
// DELETE for <baseURL>/<index>. This is a polite action to notify any
// subscribers that might be waiting for this segment.
//
// It is only needed if the segment is dropped or otherwise errored; it is not
// required if the segment is written normally. Subscribers still work fine
// without this call. It would just take longer for them to stop waiting, until
// the segment drops out of the window of active segments.
//
// A non-200 response is returned as *HTTPError.
func (p *pendingPost) Close() error {
	// End the pending POST's body. The result is deliberately ignored, since
	// this is best-effort cleanup ahead of the DELETE.
	p.writer.Close()

	url := fmt.Sprintf("%s/%d", p.client.baseURL, p.index)
	req, err := http.NewRequest("DELETE", url, nil)
	if err != nil {
		return err
	}

	// Since this method typically gets invoked when there is a problem
	// sending the segment, use a new client for a fresh connection just in
	// case. That client is discarded afterwards, so release its pooled
	// connection when done. httpClient's transport sets no IdleConnTimeout,
	// so the connection would otherwise stay open indefinitely.
	client := httpClient()
	defer client.CloseIdleConnections()

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the returned error.
		body, _ := io.ReadAll(resp.Body)
		return &HTTPError{Code: resp.StatusCode, Body: string(body)}
	}
	return nil
}
// Write uploads data as the next segment. It takes the preconnected POST from
// Next (which also starts the POST for the segment after it) and blocks until
// that upload has completed.
func (c *TricklePublisher) Write(data io.Reader) error {
	segment, err := c.Next()
	if err == nil {
		_, err = segment.Write(data)
	}
	return err
}
// httpClient builds a new *http.Client with its own Transport, so every call
// gets an independent connection pool. Keepalives (and thus connection reuse
// within that pool) stay enabled.
//
// No client Timeout is set, presumably because segment POSTs stream for as
// long as the segment lasts. Callers that discard the client should call
// CloseIdleConnections, since the transport sets no IdleConnTimeout.
//
// NOTE(review): TLS certificate verification is disabled (InsecureSkipVerify),
// which permits man-in-the-middle attacks. It appears intended only for
// self-signed orchestrator certs; confirm before using this against untrusted
// endpoints.
func httpClient() *http.Client {
	return &http.Client{Transport: &http.Transport{
		// Keepalives are left enabled. Uncomment to disable them, and with
		// them connection reuse:
		// DisableKeepAlives: true,

		// ignore orch certs for now
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}}
}
// humanBytes formats a byte count using binary (1024-based) units with one
// decimal place, e.g. 1536 -> "1.5 KB". Counts below 1024 (including negative
// values) are printed as whole bytes, e.g. "512 B".
func humanBytes(bytes int64) string {
	const base = 1024
	if bytes < base {
		return fmt.Sprintf("%d B", bytes)
	}

	const suffixes = "KMGTPE"
	scale := int64(base)
	idx := 0
	for rest := bytes / base; rest >= base; rest /= base {
		scale *= base
		idx++
	}
	value := float64(bytes) / float64(scale)
	return fmt.Sprintf("%.1f %cB", value, suffixes[idx])
}