Skip to content

Commit 5302694

Browse files
committed
add proxy polling restriction
Signed-off-by: peekjef72 <[email protected]>
1 parent 83cae52 commit 5302694

File tree

2 files changed

+223
-43
lines changed

2 files changed

+223
-43
lines changed

Diff for: cmd/proxy/coordinator.go

+21-6
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ func (c *Coordinator) getRequestChannel(fqdn string) chan *http.Request {
8888
return ch
8989
}
9090

91+
func (c *Coordinator) checkRequestChannel(fqdn string) bool {
92+
c.mu.Lock()
93+
defer c.mu.Unlock()
94+
_, ok := c.waiting[fqdn]
95+
return ok
96+
}
97+
9198
func (c *Coordinator) getResponseChannel(id string) chan *http.Response {
9299
c.mu.Lock()
93100
defer c.mu.Unlock()
@@ -116,7 +123,7 @@ func (c *Coordinator) DoScrape(ctx context.Context, r *http.Request) (*http.Resp
116123
r.Header.Add("Id", id)
117124
select {
118125
case <-ctx.Done():
119-
return nil, fmt.Errorf("Timeout reached for %q: %s", r.URL.String(), ctx.Err())
126+
return nil, fmt.Errorf("timeout reached for %q: %s", r.URL.String(), ctx.Err())
120127
case c.getRequestChannel(r.URL.Hostname()) <- r:
121128
}
122129

@@ -189,15 +196,23 @@ func (c *Coordinator) addKnownClient(fqdn string) {
189196
}
190197

191198
// KnownClients returns a list of alive clients
192-
func (c *Coordinator) KnownClients() []string {
199+
func (c *Coordinator) KnownClients(client string) []string {
193200
c.mu.Lock()
194201
defer c.mu.Unlock()
195202

203+
var known []string
196204
limit := time.Now().Add(-*registrationTimeout)
197-
known := make([]string, 0, len(c.known))
198-
for k, t := range c.known {
199-
if limit.Before(t) {
200-
known = append(known, k)
205+
if client != "" {
206+
known = make([]string, 0, 1)
207+
if t, ok := c.known[client]; ok && limit.Before(t) {
208+
known = append(known, client)
209+
}
210+
} else {
211+
known = make([]string, 0, len(c.known))
212+
for k, t := range c.known {
213+
if limit.Before(t) {
214+
known = append(known, k)
215+
}
201216
}
202217
}
203218
return known

Diff for: cmd/proxy/main.go

+202-37
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"io"
23+
"net"
2324
"net/http"
2425
"os"
26+
"regexp"
2527
"strings"
2628

2729
"github.com/alecthomas/kingpin/v2"
@@ -43,6 +45,7 @@ var (
4345
listenAddress = kingpin.Flag("web.listen-address", "Address to listen on for proxy and client requests.").Default(":8080").String()
4446
maxScrapeTimeout = kingpin.Flag("scrape.max-timeout", "Any scrape with a timeout higher than this will have to be clamped to this.").Default("5m").Duration()
4547
defaultScrapeTimeout = kingpin.Flag("scrape.default-timeout", "If a scrape lacks a timeout, use this value.").Default("15s").Duration()
48+
authorizedPollers = kingpin.Flag("scrape.pollers-ip", "Comma separeted list of ips addresses or networks authorized to scrap via the proxy.").Default("").String()
4649
)
4750

4851
var (
@@ -63,7 +66,10 @@ var (
6366
prometheus.HistogramOpts{
6467
Name: "pushprox_http_duration_seconds",
6568
Help: "Time taken by path",
66-
}, []string{"path"})
69+
}, []string{"path"},
70+
)
71+
72+
// hasPollersNet = false
6773
)
6874

6975
func init() {
@@ -83,38 +89,86 @@ type targetGroup struct {
8389
Labels map[string]string `json:"labels"`
8490
}
8591

92+
const (
93+
OpEgals = 1
94+
OpMatch = 2
95+
)
96+
97+
type route struct {
98+
path string
99+
regex *regexp.Regexp
100+
handler http.HandlerFunc
101+
}
102+
103+
func newRoute(op int, path string, handler http.HandlerFunc) *route {
104+
if op == OpEgals {
105+
return &route{path, nil, handler}
106+
} else if op == OpMatch {
107+
return &route{"", regexp.MustCompile("^" + path + "$"), handler}
108+
109+
} else {
110+
return nil
111+
}
112+
113+
}
114+
86115
type httpHandler struct {
87116
logger log.Logger
88117
coordinator *Coordinator
89118
mux http.Handler
90119
proxy http.Handler
120+
pollersNet map[*net.IPNet]int
91121
}
92122

93-
func newHTTPHandler(logger log.Logger, coordinator *Coordinator, mux *http.ServeMux) *httpHandler {
94-
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux}
95-
96-
// api handlers
97-
handlers := map[string]http.HandlerFunc{
98-
"/push": h.handlePush,
99-
"/poll": h.handlePoll,
100-
"/clients": h.handleListClients,
101-
"/metrics": promhttp.Handler().ServeHTTP,
102-
}
103-
for path, handlerFunc := range handlers {
104-
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
105-
handler := promhttp.InstrumentHandlerCounter(counter, http.HandlerFunc(handlerFunc))
106-
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
107-
handler = promhttp.InstrumentHandlerDuration(histogram, handler)
108-
mux.Handle(path, handler)
109-
counter.WithLabelValues("200")
110-
if path == "/push" {
111-
counter.WithLabelValues("500")
112-
}
113-
if path == "/poll" {
114-
counter.WithLabelValues("408")
115-
}
123+
func newHTTPHandler(logger log.Logger, coordinator *Coordinator, mux *http.ServeMux, pollers map[*net.IPNet]int) *httpHandler {
124+
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux, pollersNet: pollers}
125+
126+
var routes = []*route{
127+
newRoute(OpEgals, "/push", h.handlePush),
128+
newRoute(OpEgals, "/poll", h.handlePoll),
129+
newRoute(OpMatch, "/clients(/.*)?", h.handleListClients),
130+
newRoute(OpEgals, "/metrics", promhttp.Handler().ServeHTTP),
116131
}
132+
hf := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
133+
for _, route := range routes {
134+
var path string
117135

136+
if route == nil {
137+
continue
138+
}
139+
if route.regex != nil {
140+
if strings.HasPrefix(route.path, "/clients") {
141+
path = "/clients"
142+
}
143+
} else if req.URL.Path == route.path {
144+
path = route.path
145+
}
146+
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
147+
handler := promhttp.InstrumentHandlerCounter(counter, route.handler)
148+
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
149+
route.handler = promhttp.InstrumentHandlerDuration(histogram, handler)
150+
// mux.Handle(route.path, handler)
151+
counter.WithLabelValues("200")
152+
if route.path == "/push" {
153+
counter.WithLabelValues("500")
154+
}
155+
if route.path == "/poll" {
156+
counter.WithLabelValues("408")
157+
}
158+
if route.regex != nil {
159+
if route.regex != nil {
160+
if route.regex.MatchString(req.URL.Path) {
161+
route.handler(w, req)
162+
return
163+
}
164+
}
165+
} else if req.URL.Path == route.path {
166+
route.handler(w, req)
167+
return
168+
}
169+
}
170+
})
171+
h.mux = hf
118172
// proxy handler
119173
h.proxy = promhttp.InstrumentHandlerCounter(httpProxyCounter, http.HandlerFunc(h.handleProxy))
120174

@@ -128,15 +182,15 @@ func (h *httpHandler) handlePush(w http.ResponseWriter, r *http.Request) {
128182
scrapeResult, err := http.ReadResponse(bufio.NewReader(buf), nil)
129183
if err != nil {
130184
level.Error(h.logger).Log("msg", "Error reading pushed response:", "err", err)
131-
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
185+
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
132186
return
133187
}
134188
scrapeId := scrapeResult.Header.Get("Id")
135189
level.Info(h.logger).Log("msg", "Got /push", "scrape_id", scrapeId)
136190
err = h.coordinator.ScrapeResult(scrapeResult)
137191
if err != nil {
138192
level.Error(h.logger).Log("msg", "Error pushing:", "err", err, "scrape_id", scrapeId)
139-
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
193+
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
140194
}
141195
}
142196

@@ -146,29 +200,105 @@ func (h *httpHandler) handlePoll(w http.ResponseWriter, r *http.Request) {
146200
request, err := h.coordinator.WaitForScrapeInstruction(strings.TrimSpace(string(fqdn)))
147201
if err != nil {
148202
level.Info(h.logger).Log("msg", "Error WaitForScrapeInstruction:", "err", err)
149-
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), 408)
203+
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), http.StatusRequestTimeout)
150204
return
151205
}
152206
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
153207
request.WriteProxy(w) // Send full request as the body of the response.
154208
level.Info(h.logger).Log("msg", "Responded to /poll", "url", request.URL.String(), "scrape_id", request.Header.Get("Id"))
155209
}
156210

211+
// isPoller checks if caller has an IP addr in authorized nets (if any defined). It uses RemoteAddr field
212+
// from http.Request.
213+
// RETURNS:
214+
// - true and "" if no restriction is defined
215+
// - true and clientip if @ip from RemoteAddr is found in allowed nets
216+
// - false and "" else
217+
func (h *httpHandler) isPoller(r *http.Request) (bool, string) {
218+
var (
219+
ispoller = false
220+
clientip string
221+
)
222+
223+
if len(h.pollersNet) > 0 {
224+
if i := strings.Index(r.RemoteAddr, ":"); i != -1 {
225+
clientip = r.RemoteAddr[0:i]
226+
}
227+
for key := range h.pollersNet {
228+
ip := net.ParseIP(clientip)
229+
if key.Contains(ip) {
230+
ispoller = true
231+
break
232+
}
233+
}
234+
} else {
235+
ispoller = true
236+
}
237+
return ispoller, clientip
238+
}
239+
157240
// handleListClients handles requests to list available clients as a JSON array.
158241
func (h *httpHandler) handleListClients(w http.ResponseWriter, r *http.Request) {
159-
known := h.coordinator.KnownClients()
160-
targets := make([]*targetGroup, 0, len(known))
161-
for _, k := range known {
162-
targets = append(targets, &targetGroup{Targets: []string{k}})
242+
var (
243+
targets []*targetGroup
244+
lknown int
245+
client string
246+
)
247+
248+
ispoller, clientip := h.isPoller(r)
249+
// if not a poller we are not authorized to get all clients, restrict query to itself hostname
250+
if !ispoller {
251+
hosts, err := net.LookupAddr(clientip)
252+
if err != nil {
253+
level.Error(h.logger).Log("msg", "can't reverse client address", "err", err.Error())
254+
}
255+
if len(hosts) > 0 {
256+
// level.Info(h.logger).Log("hosts", fmt.Sprintf("%v", hosts))
257+
client = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
258+
} else {
259+
client = "_not_found_hostname_"
260+
}
261+
} else {
262+
if len(r.URL.Path) > 9 {
263+
client = r.URL.Path[9:]
264+
}
163265
}
164-
w.Header().Set("Content-Type", "application/json")
165-
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
166-
json.NewEncoder(w).Encode(targets)
167-
level.Info(h.logger).Log("msg", "Responded to /clients", "client_count", len(known))
266+
known := h.coordinator.KnownClients(client)
267+
lknown = len(known)
268+
if client != "" && lknown == 0 {
269+
http.Error(w, "", http.StatusNotFound)
270+
} else {
271+
targets = make([]*targetGroup, 0, lknown)
272+
for _, k := range known {
273+
targets = append(targets, &targetGroup{Targets: []string{k}})
274+
}
275+
w.Header().Set("Content-Type", "application/json")
276+
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
277+
json.NewEncoder(w).Encode(targets)
278+
}
279+
level.Info(h.logger).Log("msg", "Responded to /clients", "client_count", lknown)
168280
}
169281

170282
// handleProxy handles proxied scrapes from Prometheus.
171283
func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
284+
if ok, clientip := h.isPoller(r); !ok {
285+
var clientfqdn string
286+
hosts, err := net.LookupAddr(clientip)
287+
if err != nil {
288+
level.Error(h.logger).Log("msg", "can't reverse client address", "err", err.Error())
289+
}
290+
if len(hosts) > 0 {
291+
// level.Info(h.logger).Log("hosts", fmt.Sprintf("%v", hosts))
292+
clientfqdn = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
293+
} else {
294+
clientfqdn = "_not_found_hostname_"
295+
}
296+
if !h.coordinator.checkRequestChannel(clientfqdn) {
297+
http.Error(w, "Not an authorized poller", http.StatusForbidden)
298+
return
299+
}
300+
}
301+
172302
ctx, cancel := context.WithTimeout(r.Context(), util.GetScrapeTimeout(maxScrapeTimeout, defaultScrapeTimeout, r.Header))
173303
defer cancel()
174304
request := r.WithContext(ctx)
@@ -177,7 +307,7 @@ func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
177307
resp, err := h.coordinator.DoScrape(ctx, request)
178308
if err != nil {
179309
level.Error(h.logger).Log("msg", "Error scraping:", "err", err, "url", request.URL.String())
180-
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), 500)
310+
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), http.StatusInternalServerError)
181311
return
182312
}
183313
defer resp.Body.Close()
@@ -193,6 +323,18 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
193323
}
194324
}
195325

326+
// return list of network addresses from the httpHandlet.pollersNet map
327+
func (h *httpHandler) pollersNetString() string {
328+
if len(h.pollersNet) > 0 {
329+
l := make([]string, 0, len(h.pollersNet))
330+
for netw := range h.pollersNet {
331+
l = append(l, netw.String())
332+
}
333+
return strings.Join(l, ",")
334+
} else {
335+
return ""
336+
}
337+
}
196338
func main() {
197339
promlogConfig := promlog.Config{}
198340
flag.AddFlags(kingpin.CommandLine, &promlogConfig)
@@ -204,11 +346,34 @@ func main() {
204346
level.Error(logger).Log("msg", "Coordinator initialization failed", "err", err)
205347
os.Exit(1)
206348
}
349+
pollersNet := make(map[*net.IPNet]int, 10)
350+
if *authorizedPollers != "" {
351+
networks := strings.Split(*authorizedPollers, ",")
352+
for _, network := range networks {
353+
if !strings.Contains(network, "/") {
354+
// detect ipv6
355+
if strings.Contains(network, ":") {
356+
network = fmt.Sprintf("%s/128", network)
357+
} else {
358+
network = fmt.Sprintf("%s/32", network)
359+
}
360+
}
361+
if _, subnet, err := net.ParseCIDR(network); err != nil {
362+
level.Error(logger).Log("msg", "network is invalid", "net", network, "err", err)
363+
os.Exit(1)
364+
} else {
365+
pollersNet[subnet] = 1
366+
}
367+
}
368+
}
207369

208370
mux := http.NewServeMux()
209-
handler := newHTTPHandler(logger, coordinator, mux)
371+
handler := newHTTPHandler(logger, coordinator, mux, pollersNet)
210372

211373
level.Info(logger).Log("msg", "Listening", "address", *listenAddress)
374+
if len(pollersNet) > 0 {
375+
level.Info(logger).Log("msg", "Polling restricted", "allowed", handler.pollersNetString())
376+
}
212377
if err := http.ListenAndServe(*listenAddress, handler); err != nil {
213378
level.Error(logger).Log("msg", "Listening failed", "err", err)
214379
os.Exit(1)

0 commit comments

Comments
 (0)