diff --git a/.gitignore b/.gitignore index 2a2b503..4ceafef 100644 --- a/.gitignore +++ b/.gitignore @@ -30,4 +30,12 @@ go.sum # OS generated files .DS_Store +# builds openmcpauthproxy + +# test out files +coverage.out +coverage.html + +# IDE files +.vscode diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7fd2574 --- /dev/null +++ b/Makefile @@ -0,0 +1,73 @@ +# Makefile for open-mcp-auth-proxy + +# Variables +BINARY_NAME := openmcpauthproxy +GO := go +GOFMT := gofmt +GOVET := go vet +GOTEST := go test +GOLINT := golangci-lint +GOCOV := go tool cover +BUILD_DIR := build + +# Source files +SRC := $(shell find . -name "*.go" -not -path "./vendor/*") +PKGS := $(shell go list ./... | grep -v /vendor/) + +# Set build options +BUILD_OPTS := -v + +# Set test options +TEST_OPTS := -v -race + +.PHONY: all build clean test fmt lint vet coverage help + +# Default target +all: lint test build + +# Build the application +build: + @echo "Building $(BINARY_NAME)..." + @mkdir -p $(BUILD_DIR) + $(GO) build $(BUILD_OPTS) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/proxy + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + @rm -rf $(BUILD_DIR) + @rm -f coverage.out + +# Run tests +test: + @echo "Running tests..." + $(GOTEST) $(TEST_OPTS) ./... + +# Run tests with coverage report +coverage: + @echo "Running tests with coverage..." + @$(GOTEST) -coverprofile=coverage.out ./... + @$(GOCOV) -func=coverage.out + @$(GOCOV) -html=coverage.out -o coverage.html + @echo "Coverage report generated in coverage.html" + +# Run gofmt +fmt: + @echo "Running gofmt..." + @$(GOFMT) -w -s $(SRC) + +# Run go vet +vet: + @echo "Running go vet..." + @$(GOVET) ./... + +# Show help +help: + @echo "Available targets:" + @echo " all : Run lint, test, and build" + @echo " build : Build the application" + @echo " clean : Clean build artifacts" + @echo " test : Run tests" + @echo " coverage : Run tests with coverage report" + @echo " fmt : Run gofmt" + @echo " vet : Run go vet" + @echo " help : Show this help message" diff --git a/internal/authz/default_test.go b/internal/authz/default_test.go new file mode 100644 index 0000000..f40030f --- /dev/null +++ b/internal/authz/default_test.go @@ -0,0 +1,125 @@ +package authz + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +func TestNewDefaultProvider(t *testing.T) { + cfg := &config.Config{} + provider := NewDefaultProvider(cfg) + + if provider == nil { + t.Fatal("Expected non-nil provider") + } + + // Ensure it implements the Provider interface + var _ Provider = provider +} + +func TestDefaultProviderWellKnownHandler(t *testing.T) { + // Create a config with a custom well-known response + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/.well-known/oauth-authorization-server": { + Response: &config.ResponseConfig{ + Issuer: "https://test-issuer.com", + JwksURI: "https://test-issuer.com/jwks", + ResponseTypesSupported: []string{"code"}, + GrantTypesSupported: []string{"authorization_code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + }, + }, + }, + }, + } + + provider := NewDefaultProvider(cfg) + handler := provider.WellKnownHandler() + + // Create a test request + req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil) + req.Host = "test-host.com" + req.Header.Set("X-Forwarded-Proto", "https") + + // Create a response recorder + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response status + if w.Code != http.StatusOK { + t.Errorf("Expected status OK, got %v", w.Code) + } + + // Verify content type + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %s", contentType) + } + + // Decode and check the response body + var response map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response JSON: %v", err) + } + + // Check expected values + if response["issuer"] != "https://test-issuer.com" { + t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"]) + } + if response["jwks_uri"] != "https://test-issuer.com/jwks" { + t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"]) + } + if response["authorization_endpoint"] != "https://test-host.com/authorize" { + t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"]) + } +} + +func TestDefaultProviderHandleOPTIONS(t *testing.T) { + provider := NewDefaultProvider(&config.Config{}) + handler := provider.WellKnownHandler() + + // Create OPTIONS request + req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil) + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response + if w.Code != http.StatusNoContent { + t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code) + } + + // Check CORS headers + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin")) + } + if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" { + t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods")) + } +} + +func TestDefaultProviderInvalidMethod(t *testing.T) { + provider := NewDefaultProvider(&config.Config{}) + handler := provider.WellKnownHandler() + + // Create POST request (which should be rejected) + req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil) + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..20c0893 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,196 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfig(t *testing.T) { + // Create a temporary config file + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "test_config.yaml") + + // Basic valid config + validConfig := ` +listen_port: 8080 +base_url: "http://localhost:8000" +transport_mode: "sse" +paths: + sse: "/sse" + messages: "/messages" +cors: + allowed_origins: + - "http://localhost:5173" + allowed_methods: + - "GET" + - "POST" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true +` + err := os.WriteFile(configPath, []byte(validConfig), 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + // Test loading the valid config + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("Failed to load valid config: %v", err) + } + + // Verify expected values from the config + if cfg.ListenPort != 8080 { + t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort) + } + if cfg.BaseURL != "http://localhost:8000" { + t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL) + } + if cfg.TransportMode != SSETransport { + t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode) + } + if cfg.Paths.SSE != "/sse" { + t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE) + } + if cfg.Paths.Messages != "/messages" { + t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages) + } + + // Test default values + if cfg.TimeoutSeconds != 15 { + t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds) + } + if cfg.Port != 8000 { + t.Errorf("Expected default Port=8000, got %d", cfg.Port) + } +} + +func TestValidate(t *testing.T) { + tests := []struct { + name string + config Config + expectError bool + }{ + { + name: "Valid SSE config", + config: Config{ + TransportMode: SSETransport, + Paths: PathsConfig{ + SSE: "/sse", + Messages: "/messages", + }, + BaseURL: "http://localhost:8000", + }, + expectError: false, + }, + { + name: "Valid stdio config", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: true, + UserCommand: "some-command", + }, + }, + expectError: false, + }, + { + name: "Invalid stdio config - not enabled", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: false, + UserCommand: "some-command", + }, + }, + expectError: true, + }, + { + name: "Invalid stdio config - no command", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: true, + UserCommand: "", + }, + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.config.Validate() + if tc.expectError && err == nil { + t.Errorf("Expected validation error but got none") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no validation error but got: %v", err) + } + }) + } +} + +func TestGetMCPPaths(t *testing.T) { + cfg := Config{ + Paths: PathsConfig{ + SSE: "/custom-sse", + Messages: "/custom-messages", + }, + } + + paths := cfg.GetMCPPaths() + if len(paths) != 2 { + t.Errorf("Expected 2 MCP paths, got %d", len(paths)) + } + if paths[0] != "/custom-sse" { + t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) + } + if paths[1] != "/custom-messages" { + t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) + } +} + +func TestBuildExecCommand(t *testing.T) { + tests := []struct { + name string + config Config + expectedResult string + }{ + { + name: "Valid command", + config: Config{ + Stdio: StdioConfig{ + UserCommand: "test-command", + }, + Port: 8080, + BaseURL: "http://example.com", + Paths: PathsConfig{ + SSE: "/sse-path", + Messages: "/msgs", + }, + }, + expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`, + }, + { + name: "Empty command", + config: Config{ + Stdio: StdioConfig{ + UserCommand: "", + }, + }, + expectedResult: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := tc.config.BuildExecCommand() + if result != tc.expectedResult { + t.Errorf("Expected command=%s, got %s", tc.expectedResult, result) + } + }) + } +} diff --git a/internal/proxy/modifier_test.go b/internal/proxy/modifier_test.go new file mode 100644 index 0000000..3d2fd44 --- /dev/null +++ b/internal/proxy/modifier_test.go @@ -0,0 +1,147 @@ +package proxy + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +func TestAuthorizationModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/authorize": { + AddQueryParams: []config.ParamConfig{ + {Name: "client_id", Value: "test-client-id"}, + {Name: "scope", Value: "openid"}, + }, + }, + }, + }, + } + + modifier := &AuthorizationModifier{Config: cfg} + + // Create a test request + req, err := http.NewRequest("GET", "/authorize?response_type=code", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + // Check that the query parameters were added + query := modifiedReq.URL.Query() + if query.Get("client_id") != "test-client-id" { + t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id")) + } + if query.Get("scope") != "openid" { + t.Errorf("Expected scope=openid, got %s", query.Get("scope")) + } + if query.Get("response_type") != "code" { + t.Errorf("Expected response_type=code, got %s", query.Get("response_type")) + } +} + +func TestTokenModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/token": { + AddBodyParams: []config.ParamConfig{ + {Name: "audience", Value: "test-audience"}, + }, + }, + }, + }, + } + + modifier := &TokenModifier{Config: cfg} + + // Create a test request with form data + form := url.Values{} + + req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + body := make([]byte, 1024) + n, err := modifiedReq.Body.Read(body) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read body: %v", err) + } + bodyStr := string(body[:n]) + + // Parse the form data from the modified request + if err := modifiedReq.ParseForm(); err != nil { + t.Fatalf("Failed to parse form data: %v", err) + } + + // Check that the body parameters were added + if !strings.Contains(bodyStr, "audience") { + t.Errorf("Expected body to contain audience, got %s", bodyStr) + } +} + +func TestRegisterModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/register": { + AddBodyParams: []config.ParamConfig{ + {Name: "client_name", Value: "test-client"}, + }, + }, + }, + }, + } + + modifier := &RegisterModifier{Config: cfg} + + // Create a test request with JSON data + jsonBody := `{"redirect_uris":["https://example.com/callback"]}` + req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + // Read the body and check that it still contains the original data + // This test would need to be enhanced with a proper JSON parsing to verify + // the added parameters + body := make([]byte, 1024) + n, err := modifiedReq.Body.Read(body) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read body: %v", err) + } + bodyStr := string(body[:n]) + + // Simple check to see if the modified body contains the expected fields + if !strings.Contains(bodyStr, "client_name") { + t.Errorf("Expected body to contain client_name, got %s", bodyStr) + } + if !strings.Contains(bodyStr, "redirect_uris") { + t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr) + } +} diff --git a/internal/util/jwks_test.go b/internal/util/jwks_test.go new file mode 100644 index 0000000..3b00c68 --- /dev/null +++ b/internal/util/jwks_test.go @@ -0,0 +1,143 @@ +package util + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +func TestValidateJWT(t *testing.T) { + // Initialize the test JWKS data + initTestJWKS(t) + + // Test cases + tests := []struct { + name string + authHeader string + expectError bool + }{ + { + name: "Valid JWT token", + authHeader: "Bearer " + createValidJWT(t), + expectError: false, + }, + { + name: "No auth header", + authHeader: "", + expectError: true, + }, + { + name: "Invalid auth header format", + authHeader: "InvalidFormat", + expectError: true, + }, + { + name: "Invalid JWT token", + authHeader: "Bearer invalid.jwt.token", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateJWT(tc.authHeader) + if tc.expectError && err == nil { + t.Errorf("Expected error but got none") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + }) + } +} + +func TestFetchJWKS(t *testing.T) { + // Create a mock JWKS server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Generate a test RSA key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Create JWKS response + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-id", + "n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537 + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + })) + defer server.Close() + + // Test fetching JWKS + err := FetchJWKS(server.URL) + if err != nil { + t.Fatalf("FetchJWKS failed: %v", err) + } + + // Check that keys were stored + if len(publicKeys) == 0 { + t.Errorf("Expected publicKeys to be populated") + } +} + +// Helper function to initialize test JWKS data +func initTestJWKS(t *testing.T) { + // Create a test RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Initialize the publicKeys map + publicKeys = map[string]*rsa.PublicKey{ + "test-key-id": &privateKey.PublicKey, + } +} + +// Helper function to create a valid JWT token for testing +func createValidJWT(t *testing.T) string { + // Create a test RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Ensure the test key is in the publicKeys map + if publicKeys == nil { + publicKeys = map[string]*rsa.PublicKey{} + } + publicKeys["test-key-id"] = &privateKey.PublicKey + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "sub": "1234567890", + "name": "Test User", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-id" + + // Sign the token + tokenString, err := token.SignedString(privateKey) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + return tokenString +}