Skip to content

Commit 6208081

Browse files
authored
Fix: PIA port forwarding (#427)
- Update PIA token URL - Change base64 decoding to standard decoding - Add unit tests - Remove environment variable `GODEBUG=x509ignoreCN=0` - Fixes #423 - Fixes #292 - Closes #264 - Closes #293
1 parent 3795e92 commit 6208081

3 files changed

Lines changed: 175 additions & 36 deletions

File tree

Dockerfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,4 @@ RUN apk add -q --progress --no-cache --update openvpn ca-certificates iptables i
147147
deluser unbound && \
148148
mkdir /gluetun
149149
# TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10
150-
ENV GODEBUG=x509ignoreCN=0
151150
COPY --from=build /tmp/gobuild/entrypoint /entrypoint

internal/provider/piav4.go

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
221221
return
222222
}
223223

224-
client, err := newPIAHTTPClient(commonName)
224+
privateIPClient, err := newPIAHTTPClient(commonName)
225225
if err != nil {
226226
pfLogger.Error("aborting because: %s", err)
227227
return
@@ -246,7 +246,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
246246

247247
if !dataFound || expired {
248248
tryUntilSuccessful(ctx, pfLogger, func() error {
249-
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
249+
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
250250
return err
251251
})
252252
if ctx.Err() != nil {
@@ -258,7 +258,10 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
258258

259259
// First time binding
260260
tryUntilSuccessful(ctx, pfLogger, func() error {
261-
return bindPIAPort(ctx, client, gateway, data)
261+
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
262+
return fmt.Errorf("cannot bind port: %w", err)
263+
}
264+
return nil
262265
})
263266
if ctx.Err() != nil {
264267
return
@@ -294,15 +297,15 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
294297
}
295298
return
296299
case <-keepAliveTimer.C:
297-
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
298-
pfLogger.Error(err)
300+
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
301+
pfLogger.Error("cannot bind port: " + err.Error())
299302
}
300303
keepAliveTimer.Reset(keepAlivePeriod)
301304
case <-expiryTimer.C:
302305
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
303306
oldPort := data.Port
304307
for {
305-
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
308+
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
306309
if err != nil {
307310
pfLogger.Error(err)
308311
continue
@@ -322,8 +325,8 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
322325
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
323326
pfLogger.Error(err)
324327
}
325-
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
326-
pfLogger.Error(err)
328+
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
329+
pfLogger.Error("cannot bind port: " + err.Error())
327330
}
328331
if !keepAliveTimer.Stop() {
329332
<-keepAliveTimer.C
@@ -357,41 +360,43 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
357360
if err != nil {
358361
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
359362
}
360-
// certificate.DNSNames = []string{serverName, "10.0.0.1"}
361-
rootCAs := x509.NewCertPool()
362-
rootCAs.AddCert(certificate)
363-
TLSClientConfig := &tls.Config{
364-
RootCAs: rootCAs,
365-
MinVersion: tls.VersionTLS12,
366-
ServerName: serverName,
367-
}
363+
368364
//nolint:gomnd
369-
transport := http.Transport{
370-
TLSClientConfig: TLSClientConfig,
371-
Proxy: http.ProxyFromEnvironment,
365+
transport := &http.Transport{
366+
// Settings taken from http.DefaultTransport
367+
Proxy: http.ProxyFromEnvironment,
372368
DialContext: (&net.Dialer{
373369
Timeout: 30 * time.Second,
374370
KeepAlive: 30 * time.Second,
375-
DualStack: true,
376371
}).DialContext,
377372
ForceAttemptHTTP2: true,
378373
MaxIdleConns: 100,
379374
IdleConnTimeout: 90 * time.Second,
380375
TLSHandshakeTimeout: 10 * time.Second,
381376
ExpectContinueTimeout: 1 * time.Second,
382377
}
378+
rootCAs := x509.NewCertPool()
379+
rootCAs.AddCert(certificate)
380+
transport.TLSClientConfig = &tls.Config{
381+
RootCAs: rootCAs,
382+
MinVersion: tls.VersionTLS12,
383+
ServerName: serverName,
384+
}
385+
383386
const httpTimeout = 30 * time.Second
384-
client = &http.Client{Transport: &transport, Timeout: httpTimeout}
385-
return client, nil
387+
return &http.Client{
388+
Transport: transport,
389+
Timeout: httpTimeout,
390+
}, nil
386391
}
387392

388-
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
393+
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
389394
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
390395
data.Token, err = fetchPIAToken(ctx, openFile, client)
391396
if err != nil {
392397
return data, fmt.Errorf("cannot obtain token: %w", err)
393398
}
394-
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token)
399+
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, privateIPClient, gateway, data.Token)
395400
if err != nil {
396401
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
397402
}
@@ -448,13 +453,15 @@ func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData)
448453
}
449454

450455
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
451-
b, err := base64.RawStdEncoding.DecodeString(payload)
456+
b, err := base64.StdEncoding.DecodeString(payload)
452457
if err != nil {
453-
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err)
458+
return 0, "", expiration,
459+
fmt.Errorf("cannot decode payload: payload is %q: %w", payload, err)
454460
}
455461
var payloadData piaPayload
456462
if err := json.Unmarshal(b, &payloadData); err != nil {
457-
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err)
463+
return 0, "", expiration,
464+
fmt.Errorf("cannot parse payload data: data is %q: %w", string(b), err)
458465
}
459466
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
460467
}
@@ -469,7 +476,7 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
469476
if err != nil {
470477
return "", fmt.Errorf("cannot serialize payload data: %w", err)
471478
}
472-
payload = base64.RawStdEncoding.EncodeToString(b)
479+
payload = base64.StdEncoding.EncodeToString(b)
473480
return payload, nil
474481
}
475482

@@ -482,16 +489,18 @@ func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
482489
url := url.URL{
483490
Scheme: "https",
484491
User: url.UserPassword(username, password),
485-
Host: "10.0.0.1",
486-
Path: "/authv3/generateToken",
492+
Host: "privateinternetaccess.com",
493+
Path: "/gtoken/generateToken",
487494
}
488495
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
489496
if err != nil {
490-
return "", err
497+
return "", replaceInErr(err, map[string]string{
498+
username: "<username>", password: "<password>"})
491499
}
492500
response, err := client.Do(request)
493501
if err != nil {
494-
return "", err
502+
return "", replaceInErr(err, map[string]string{
503+
username: "<username>", password: "<password>"})
495504
}
496505
defer response.Body.Close()
497506
if response.StatusCode != http.StatusOK {
@@ -547,10 +556,12 @@ func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway n
547556
}
548557
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
549558
if err != nil {
559+
err = replaceInErr(err, map[string]string{token: "<token>"})
550560
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
551561
}
552562
response, err := client.Do(request)
553563
if err != nil {
564+
err = replaceInErr(err, map[string]string{token: "<token>"})
554565
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
555566
}
556567
defer response.Body.Close()
@@ -590,11 +601,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
590601

591602
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
592603
if err != nil {
593-
return fmt.Errorf("cannot bind port: %w", err)
604+
return replaceInErr(err, map[string]string{
605+
payload: "<payload>",
606+
data.Signature: "<signature>",
607+
})
594608
}
595609
response, err := client.Do(request)
596610
if err != nil {
597-
return fmt.Errorf("cannot bind port: %w", err)
611+
return replaceInErr(err, map[string]string{
612+
payload: "<payload>",
613+
data.Signature: "<signature>",
614+
})
598615
}
599616
defer response.Body.Close()
600617
if response.StatusCode != http.StatusOK {
@@ -607,7 +624,7 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
607624
Message string `json:"message"`
608625
}
609626
if err := decoder.Decode(&responseData); err != nil {
610-
return fmt.Errorf("cannot bind port: %w", err)
627+
return err
611628
} else if responseData.Status != "OK" {
612629
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
613630
}
@@ -627,3 +644,12 @@ func writePortForwardedToFile(openFile os.OpenFileFunc,
627644
}
628645
return file.Close()
629646
}
647+
648+
// replaceInErr is used to remove sensitive information from logs.
649+
func replaceInErr(err error, substitutions map[string]string) error {
650+
s := err.Error()
651+
for old, new := range substitutions {
652+
s = strings.ReplaceAll(s, old, new)
653+
}
654+
return errors.New(s)
655+
}

internal/provider/piav4_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package provider
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"encoding/base64"
7+
"encoding/json"
8+
"errors"
9+
"net/http"
10+
"testing"
11+
"time"
12+
13+
"github.com/qdm12/gluetun/internal/constants"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func Test_newPIAHTTPClient(t *testing.T) {
19+
t.Parallel()
20+
21+
const serverName = "testserver"
22+
23+
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
24+
require.NoError(t, err)
25+
certificate, err := x509.ParseCertificate(certificateBytes)
26+
require.NoError(t, err)
27+
rootCAs := x509.NewCertPool()
28+
rootCAs.AddCert(certificate)
29+
expectedRootCAsSubjects := rootCAs.Subjects()
30+
31+
expectedPIATransportTLSConfig := &tls.Config{
32+
// Can't directly compare RootCAs because of private fields
33+
RootCAs: nil,
34+
MinVersion: tls.VersionTLS12,
35+
ServerName: serverName,
36+
}
37+
38+
piaClient, err := newPIAHTTPClient(serverName)
39+
40+
require.NoError(t, err)
41+
42+
// Verify pia transport TLS config is set
43+
piaTransport := piaClient.Transport.(*http.Transport)
44+
rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects()
45+
assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects)
46+
piaTransport.TLSClientConfig.RootCAs = nil
47+
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
48+
}
49+
50+
func Test_unpackPIAPayload(t *testing.T) {
51+
t.Parallel()
52+
53+
const exampleToken = "token"
54+
const examplePort = 2000
55+
exampleExpiration := time.Unix(1000, 0).UTC()
56+
57+
testCases := map[string]struct {
58+
payload string
59+
port uint16
60+
token string
61+
expiration time.Time
62+
err error
63+
}{
64+
"valid payload": {
65+
payload: makePIAPayload(t, exampleToken, examplePort, exampleExpiration),
66+
port: examplePort,
67+
token: exampleToken,
68+
expiration: exampleExpiration,
69+
err: nil,
70+
},
71+
"invalid base64 payload": {
72+
payload: "invalid",
73+
err: errors.New(`cannot decode payload: payload is "invalid": illegal base64 data at input byte 4`),
74+
},
75+
"invalid json payload": {
76+
payload: base64.StdEncoding.EncodeToString([]byte{1}),
77+
err: errors.New(`cannot parse payload data: data is "\x01": invalid character '\x01' looking for beginning of value`), //nolint:lll
78+
},
79+
}
80+
81+
for name, testCase := range testCases {
82+
testCase := testCase
83+
t.Run(name, func(t *testing.T) {
84+
t.Parallel()
85+
port, token, expiration, err := unpackPIAPayload(testCase.payload)
86+
87+
if testCase.err != nil {
88+
require.Error(t, err)
89+
assert.Equal(t, testCase.err.Error(), err.Error())
90+
} else {
91+
require.NoError(t, err)
92+
}
93+
94+
assert.Equal(t, testCase.port, port)
95+
assert.Equal(t, testCase.token, token)
96+
assert.Equal(t, testCase.expiration, expiration)
97+
})
98+
}
99+
}
100+
101+
func makePIAPayload(t *testing.T, token string, port uint16, expiration time.Time) (payload string) {
102+
t.Helper()
103+
104+
data := piaPayload{
105+
Token: token,
106+
Port: port,
107+
Expiration: expiration,
108+
}
109+
110+
b, err := json.Marshal(data)
111+
require.NoError(t, err)
112+
113+
return base64.StdEncoding.EncodeToString(b)
114+
}

0 commit comments

Comments
 (0)