Skip to content

Commit 95689eb

Browse files
committed
Optionally serve authentication with HTTPS
1 parent 9376569 commit 95689eb

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

oauth/authcode.go

+54-6
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@ import (
44
"bufio"
55
"crypto/rand"
66
"crypto/sha256"
7+
"crypto/tls"
78
"encoding/base64"
89
"fmt"
910
"net/http"
1011
"net/url"
1112
"os"
1213
"os/exec"
14+
"path"
1315
"runtime"
16+
"strconv"
1417
"strings"
1518
"time"
1619

1720
"context"
1821

1922
"github.com/danielgtaylor/restish/cli"
23+
"github.com/spf13/viper"
2024
"golang.org/x/oauth2"
2125
)
2226

@@ -178,7 +182,7 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
178182

179183
// AuthorizationCodeTokenSource with PKCE as described in:
180184
// https://www.oauth.com/oauth2-servers/pkce/
181-
// This works by running a local HTTP server on port 8484 and then having the
185+
// This works by running a local HTTP or HTTPS server on port 8484 and then having the
182186
// user log in through a web browser, which redirects to the local server with
183187
// an authorization code. That code is then used to make another HTTP request
184188
// to fetch an auth token (and refresh token). That token is then in turn
@@ -190,6 +194,7 @@ type AuthorizationCodeTokenSource struct {
190194
TokenURL string
191195
EndpointParams *url.Values
192196
Scopes []string
197+
UseHTTPS bool
193198
}
194199

195200
// Token generates a new token using an authorization code.
@@ -213,12 +218,22 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
213218
panic(err)
214219
}
215220

221+
redirectURL := url.URL{
222+
Host: "localhost:8484",
223+
Path: "/",
224+
}
225+
if ac.UseHTTPS {
226+
redirectURL.Scheme = "https"
227+
} else {
228+
redirectURL.Scheme = "http"
229+
}
230+
216231
aq := authorizeURL.Query()
217232
aq.Set("response_type", "code")
218233
aq.Set("code_challenge", challenge)
219234
aq.Set("code_challenge_method", "S256")
220235
aq.Set("client_id", ac.ClientID)
221-
aq.Set("redirect_uri", "http://localhost:8484/")
236+
aq.Set("redirect_uri", redirectURL.String())
222237
aq.Set("scope", strings.Join(ac.Scopes, " "))
223238
if ac.EndpointParams != nil {
224239
for k, v := range *ac.EndpointParams {
@@ -234,16 +249,38 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
234249
}
235250

236251
s := &http.Server{
237-
Addr: "localhost:8484",
252+
Addr: redirectURL.Host,
238253
Handler: handler,
239254
ReadTimeout: 5 * time.Second,
240255
WriteTimeout: 5 * time.Second,
241256
MaxHeaderBytes: 1024,
242257
}
243258

259+
if ac.UseHTTPS {
260+
configDirectory := viper.GetString("config-directory")
261+
certName := path.Join(configDirectory, "localhost.crt")
262+
keyfileName := path.Join(configDirectory, "localhost.key")
263+
264+
cert, err := tls.LoadX509KeyPair(certName, keyfileName)
265+
if err != nil {
266+
panic(err)
267+
}
268+
269+
s.TLSConfig = &tls.Config{
270+
Certificates: []tls.Certificate{cert},
271+
}
272+
}
273+
244274
go func() {
245275
// Run in a goroutine until the server is closed or we get an error.
246-
if err := s.ListenAndServe(); err != http.ErrServerClosed {
276+
var err error
277+
if ac.UseHTTPS {
278+
err = s.ListenAndServeTLS("", "")
279+
} else {
280+
err = s.ListenAndServe()
281+
}
282+
283+
if err != http.ErrServerClosed {
247284
panic(err)
248285
}
249286
}()
@@ -279,7 +316,7 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
279316
payload.Set("client_id", ac.ClientID)
280317
payload.Set("code_verifier", verifier)
281318
payload.Set("code", code)
282-
payload.Set("redirect_uri", "http://localhost:8484/")
319+
payload.Set("redirect_uri", redirectURL.String())
283320
if ac.ClientSecret != "" {
284321
payload.Set("client_secret", ac.ClientSecret)
285322
}
@@ -299,6 +336,7 @@ func (h *AuthorizationCodeHandler) Parameters() []cli.AuthParam {
299336
{Name: "authorize_url", Required: true, Help: "OAuth 2.0 authorization URL, e.g. https://api.example.com/oauth/authorize"},
300337
{Name: "token_url", Required: true, Help: "OAuth 2.0 token URL, e.g. https://api.example.com/oauth/token"},
301338
{Name: "scopes", Help: "Optional scopes to request in the token"},
339+
{Name: "use_https", Help: "Use HTTPS for authentication page"},
302340
}
303341
}
304342

@@ -307,21 +345,31 @@ func (h *AuthorizationCodeHandler) OnRequest(request *http.Request, key string,
307345
if request.Header.Get("Authorization") == "" {
308346
endpointParams := url.Values{}
309347
for k, v := range params {
310-
if k == "client_id" || k == "client_secret" || k == "scopes" || k == "authorize_url" || k == "token_url" {
348+
if k == "client_id" || k == "client_secret" || k == "scopes" || k == "authorize_url" || k == "token_url" || k == "use_https" {
311349
// Not a custom param...
312350
continue
313351
}
314352

315353
endpointParams.Add(k, v)
316354
}
317355

356+
var useHTTPS bool
357+
if v := params["use_https"]; v != "" {
358+
var err error
359+
useHTTPS, err = strconv.ParseBool(v)
360+
if err != nil {
361+
return err
362+
}
363+
}
364+
318365
source := &AuthorizationCodeTokenSource{
319366
ClientID: params["client_id"],
320367
ClientSecret: params["client_secret"],
321368
AuthorizeURL: params["authorize_url"],
322369
TokenURL: params["token_url"],
323370
EndpointParams: &endpointParams,
324371
Scopes: strings.Split(params["scopes"], ","),
372+
UseHTTPS: useHTTPS,
325373
}
326374

327375
// Try to get a cached refresh token from the current profile and use

0 commit comments

Comments
 (0)