Skip to content

Commit 0674631

Browse files
committed
add max wait and block requests
1 parent 43c9e93 commit 0674631

File tree

3 files changed

+103
-16
lines changed

3 files changed

+103
-16
lines changed

.golangci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ run:
44
linters:
55
enable-all: true
66
disable:
7+
- wsl
78
- contextcheck
89
- gomnd
910
- gochecknoinits

cmd/serve/main.go

+24-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"flag"
55
"log"
66
"net/http"
7+
"strings"
8+
"time"
79

810
"github.com/go-rod/bartender"
911
)
@@ -12,6 +14,10 @@ func main() {
1214
port := flag.String("p", ":3001", "port to listen on")
1315
target := flag.String("t", "", "target url to proxy")
1416
size := flag.Int("s", 2, "size of the pool")
17+
maxWait := flag.Duration("w", 3*time.Second, "max wait time for a page rendering")
18+
19+
var block BlockRequestsFlag
20+
flag.Var(&block, "b", "block the requests that match the pattern, such as 'https://a.com/*', can set multiple ones")
1521

1622
flag.Parse()
1723

@@ -21,8 +27,25 @@ func main() {
2127

2228
log.Printf("Bartender started %s -> %s\n", *port, *target)
2329

24-
err := http.ListenAndServe(*port, bartender.New(*port, *target, *size))
30+
b := bartender.New(*port, *target, *size)
31+
b.BlockRequest(block...)
32+
b.MaxWait(*maxWait)
33+
b.WarnUp()
34+
35+
err := http.ListenAndServe(*port, b)
2536
if err != nil {
2637
log.Fatalln(err)
2738
}
2839
}
40+
41+
// BlockRequestsFlag is a flag.Value that accumulates every value given to a
// repeatable command-line flag, in the order the values were supplied.
type BlockRequestsFlag []string

// String returns the collected values joined with ", ".
//
// The flag package may call String on a zero-valued receiver, including a nil
// pointer, so a nil receiver yields the empty string instead of panicking.
func (f *BlockRequestsFlag) String() string {
	if f == nil {
		return ""
	}

	return strings.Join(*f, ", ")
}

// Set appends value to the collected list. It is invoked once per occurrence
// of the flag on the command line and never returns an error.
func (f *BlockRequestsFlag) Set(value string) error {
	*f = append(*f, value)

	return nil
}

service.go

+78-15
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@ import (
88
"net/http/httputil"
99
"net/url"
1010
"strings"
11+
"sync"
12+
"time"
1113

1214
"github.com/go-rod/rod"
15+
"github.com/go-rod/rod/lib/proto"
1316
"github.com/mileusna/useragent"
1417
)
1518

1619
type Bartender struct {
17-
addr string
18-
target *url.URL
19-
proxy *httputil.ReverseProxy
20-
bypassList map[string]bool
21-
pool rod.PagePool
20+
addr string
21+
target *url.URL
22+
proxy *httputil.ReverseProxy
23+
bypassList map[string]bool
24+
pool rod.PagePool
25+
blockRequests []string
26+
maxWait time.Duration
2227
}
2328

2429
func New(addr, target string, poolSize int) *Bartender {
@@ -45,14 +50,53 @@ func New(addr, target string, poolSize int) *Bartender {
4550
useragent.Edge: true,
4651
useragent.Vivaldi: true,
4752
},
48-
pool: rod.NewPagePool(poolSize),
53+
pool: rod.NewPagePool(poolSize),
54+
blockRequests: []string{},
55+
maxWait: 3 * time.Second,
4956
}
5057
}
5158

5259
// BypassUserAgentNames replaces the bypass list with the given map.
//
// ServeHTTP looks up the parsed User-Agent name in this map; a name mapped to
// true takes the same branch as a non-GET request (presumably it is proxied
// without headless rendering — confirm against ServeHTTP).
// The map is stored as-is, not copied.
func (b *Bartender) BypassUserAgentNames(list map[string]bool) {
	b.bypassList = list
}
5562

63+
func (b *Bartender) BlockRequest(patterns ...string) {
64+
b.blockRequests = patterns
65+
}
66+
67+
// MaxWait sets the max wait time for the headless browser to render the page.
68+
// If the max wait time is reached, bartender will stop waiting for page rendering and
69+
// immediately return the current html.
70+
func (b *Bartender) MaxWait(d time.Duration) {
71+
b.maxWait = d
72+
}
73+
74+
func (b *Bartender) newPage() *rod.Page {
75+
page := rod.New().MustConnect().MustPage()
76+
77+
if len(b.blockRequests) > 0 {
78+
router := page.HijackRequests()
79+
80+
for _, pattern := range b.blockRequests {
81+
router.MustAdd(pattern, func(ctx *rod.Hijack) {
82+
ctx.Response.Fail(proto.NetworkErrorReasonBlockedByClient)
83+
})
84+
}
85+
86+
go router.Run()
87+
}
88+
89+
log.Println("headless browser started:", page.SessionID)
90+
91+
return page
92+
}
93+
94+
func (b *Bartender) WarnUp() {
95+
for i := 0; i < len(b.pool); i++ {
96+
b.pool.Put(b.pool.Get(b.newPage))
97+
}
98+
}
99+
56100
func (b *Bartender) ServeHTTP(w http.ResponseWriter, r *http.Request) {
57101
ua := useragent.Parse(r.Header.Get("User-Agent"))
58102
if r.Method != http.MethodGet || b.bypassList[ua.Name] {
@@ -80,11 +124,6 @@ func (b *Bartender) RenderPage(w http.ResponseWriter, r *http.Request) bool {
80124

81125
log.Println("headless render:", u)
82126

83-
page := b.pool.Get(func() *rod.Page { return rod.New().MustConnect().MustPage() })
84-
defer b.pool.Put(page)
85-
86-
page.MustNavigate(u).MustWaitStable()
87-
88127
for k, vs := range resHeader {
89128
if k == "Content-Length" {
90129
continue
@@ -97,10 +136,34 @@ func (b *Bartender) RenderPage(w http.ResponseWriter, r *http.Request) bool {
97136

98137
w.WriteHeader(statusCode)
99138

100-
_, err := w.Write([]byte(page.MustHTML()))
101-
if err != nil {
102-
panic(err)
103-
}
139+
page := b.pool.Get(b.newPage)
140+
defer b.pool.Put(page)
141+
142+
page, cancel := page.WithCancel()
143+
144+
once := sync.Once{}
145+
146+
go func() {
147+
time.Sleep(b.maxWait)
148+
log.Println("max wait time reached, return current html:", u)
149+
once.Do(func() {
150+
body, _ := page.HTML()
151+
_, _ = w.Write([]byte(body))
152+
cancel()
153+
})
154+
}()
155+
156+
_ = page.Context(r.Context()).Navigate(u)
157+
158+
_ = page.WaitStable(time.Second)
159+
160+
body, _ := page.HTML()
161+
162+
log.Println("headless render done:", u)
163+
164+
once.Do(func() {
165+
_, _ = w.Write([]byte(body))
166+
})
104167

105168
return true
106169
}

0 commit comments

Comments
 (0)