Skip to content

Commit 76027e3

Browse files
authored
feat: add validation for sw_url (#41)
1 parent b5cb25a commit 76027e3

File tree

8 files changed

+132
-10
lines changed

8 files changed

+132
-10
lines changed

internal/swmcp/server.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,25 @@ func WithSkyWalkingAuth(ctx context.Context, username, password string) context.
9393
// The value is sourced from the CLI/config binding for `--sw-url`,
9494
// falling back to the built-in default when unset.
9595
func configuredSkyWalkingURL() string {
96+
resolvedURL, err := resolvedConfiguredSkyWalkingURL()
97+
if err != nil {
98+
logrus.WithError(err).Warn("invalid SkyWalking OAP URL configuration; falling back to default URL")
99+
return config.DefaultSWURL
100+
}
101+
return resolvedURL
102+
}
103+
104+
func resolvedConfiguredSkyWalkingURL() (string, error) {
96105
urlStr := viper.GetString("url")
97106
if urlStr == "" {
98107
urlStr = config.DefaultSWURL
99108
}
100-
return tools.FinalizeURL(urlStr)
109+
return tools.NormalizeOAPURL(urlStr)
110+
}
111+
112+
func validateConfiguredSkyWalkingURL() error {
113+
_, err := resolvedConfiguredSkyWalkingURL()
114+
return err
101115
}
102116

103117
// resolveEnvVar resolves a value that may contain an environment variable reference

internal/swmcp/server_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323

2424
"github.com/apache/skywalking-cli/pkg/contextkey"
25+
"github.com/spf13/cobra"
2526
"github.com/spf13/viper"
2627

2728
"github.com/apache/skywalking-mcp/internal/config"
@@ -52,6 +53,43 @@ func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) {
5253
}
5354
}
5455

56+
func TestConfiguredSkyWalkingURLFallsBackToDefaultOnInvalidValue(t *testing.T) {
57+
t.Cleanup(viper.Reset)
58+
viper.Set("url", "ftp://configured-oap.example.com:12800")
59+
60+
got := configuredSkyWalkingURL()
61+
if got != config.DefaultSWURL {
62+
t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL)
63+
}
64+
}
65+
66+
func TestValidateConfiguredSkyWalkingURLRejectsUnsupportedScheme(t *testing.T) {
67+
t.Cleanup(viper.Reset)
68+
viper.Set("url", "ftp://configured-oap.example.com:12800")
69+
70+
err := validateConfiguredSkyWalkingURL()
71+
if err == nil {
72+
t.Fatal("validateConfiguredSkyWalkingURL() error = nil, want error")
73+
}
74+
}
75+
76+
func TestTransportCommandsRejectInvalidSWURL(t *testing.T) {
77+
t.Cleanup(viper.Reset)
78+
viper.Set("url", "ftp://configured-oap.example.com:12800")
79+
80+
for name, cmd := range map[string]*cobra.Command{
81+
"stdio": NewStdioServer(),
82+
"sse": NewSSEServer(),
83+
"streamable": NewStreamable(),
84+
} {
85+
t.Run(name, func(t *testing.T) {
86+
if err := cmd.RunE(cmd, nil); err == nil {
87+
t.Fatal("RunE() error = nil, want invalid sw-url error")
88+
}
89+
})
90+
}
91+
}
92+
5593
func TestResolveEnvVar(t *testing.T) {
5694
t.Setenv("SW_TEST_SECRET", "resolved-secret")
5795

internal/swmcp/sse.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ func NewSSEServer() *cobra.Command {
4141
Short: "Start SSE server",
4242
Long: `Start a server that listens for Server-Sent Events (SSE) on the specified address.`,
4343
RunE: func(_ *cobra.Command, _ []string) error {
44+
if err := validateConfiguredSkyWalkingURL(); err != nil {
45+
return err
46+
}
47+
4448
sseServerConfig := config.SSEServerConfig{
4549
Address: viper.GetString("sse-address"),
4650
BasePath: viper.GetString("base-path"),

internal/swmcp/stdio.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ func NewStdioServer() *cobra.Command {
4141
Short: "Start stdio server",
4242
Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`,
4343
RunE: func(_ *cobra.Command, _ []string) error {
44+
if err := validateConfiguredSkyWalkingURL(); err != nil {
45+
return err
46+
}
47+
4448
stdioServerConfig := config.StdioServerConfig{
4549
URL: viper.GetString("url"),
4650
ReadOnly: viper.GetBool("read-only"),

internal/swmcp/streamable.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ func NewStreamable() *cobra.Command {
3636
Short: "Start Streamable server",
3737
Long: `Starting SkyWalking MCP server with Streamable HTTP transport.`,
3838
RunE: func(_ *cobra.Command, _ []string) error {
39+
if err := validateConfiguredSkyWalkingURL(); err != nil {
40+
return err
41+
}
42+
3943
streamableConfig := config.StreamableServerConfig{
4044
Address: viper.GetString("address"),
4145
EndpointPath: viper.GetString("endpoint-path"),

internal/tools/common.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,37 @@ const (
4545

4646
// FinalizeURL ensures the URL ends with "/graphql".
4747
func FinalizeURL(urlStr string) string {
48-
if !strings.HasSuffix(urlStr, "/graphql") {
49-
urlStr = strings.TrimRight(urlStr, "/") + "/graphql"
48+
normalizedURL, err := NormalizeOAPURL(urlStr)
49+
if err == nil {
50+
return normalizedURL
5051
}
5152
return urlStr
5253
}
5354

54-
// validateURLScheme ensures the URL uses http or https.
55-
func validateURLScheme(rawURL string) error {
55+
// NormalizeOAPURL parses and validates the OAP URL, then ensures the path ends with /graphql.
56+
func NormalizeOAPURL(rawURL string) (string, error) {
5657
u, err := url.Parse(rawURL)
5758
if err != nil {
58-
return fmt.Errorf("invalid OAP URL: %w", err)
59+
return "", fmt.Errorf("invalid OAP URL: %w", err)
60+
}
61+
if err := validateURLScheme(u); err != nil {
62+
return "", err
63+
}
64+
if u.Host == "" {
65+
return "", fmt.Errorf("invalid OAP URL %q: host is required", rawURL)
66+
}
67+
68+
if u.Path == "" || u.Path == "/" {
69+
u.Path = "/graphql"
70+
} else if !strings.HasSuffix(u.Path, "/graphql") {
71+
u.Path = strings.TrimRight(u.Path, "/") + "/graphql"
5972
}
73+
74+
return u.String(), nil
75+
}
76+
77+
// validateURLScheme ensures the URL uses http or https.
78+
func validateURLScheme(u *url.URL) error {
6079
if u.Scheme != "http" && u.Scheme != "https" {
6180
return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme)
6281
}

internal/tools/common_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tools
1818

1919
import (
20+
"strings"
2021
"testing"
2122
"time"
2223

@@ -39,6 +40,7 @@ func TestFinalizeURL(t *testing.T) {
3940
{name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"},
4041
{name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"},
4142
{name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"},
43+
{name: "preserves query string", in: "http://localhost:12800?x=1", want: "http://localhost:12800/graphql?x=1"},
4244
}
4345

4446
for _, tc := range tests {
@@ -50,6 +52,44 @@ func TestFinalizeURL(t *testing.T) {
5052
}
5153
}
5254

55+
func TestNormalizeOAPURL(t *testing.T) {
56+
tests := []struct {
57+
name string
58+
in string
59+
want string
60+
wantErr string
61+
}{
62+
{name: "http", in: "http://localhost:12800", want: "http://localhost:12800/graphql"},
63+
{name: "https", in: "https://localhost:12800/graphql", want: "https://localhost:12800/graphql"},
64+
{name: "preserves query and fragment", in: "https://localhost:12800/oap?debug=1#frag", want: "https://localhost:12800/oap/graphql?debug=1#frag"},
65+
{name: "rejects unsupported scheme", in: "ftp://localhost:12800", wantErr: "unsupported OAP URL scheme \"ftp\""},
66+
{name: "rejects missing host", in: "http://", wantErr: "host is required"},
67+
{name: "rejects malformed hostless path", in: "http:/foo", wantErr: "host is required"},
68+
}
69+
70+
for _, tc := range tests {
71+
t.Run(tc.name, func(t *testing.T) {
72+
got, err := NormalizeOAPURL(tc.in)
73+
if tc.wantErr != "" {
74+
if err == nil {
75+
t.Fatalf("NormalizeOAPURL(%q) error = nil, want %q", tc.in, tc.wantErr)
76+
}
77+
if !strings.Contains(err.Error(), tc.wantErr) {
78+
t.Fatalf("NormalizeOAPURL(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr)
79+
}
80+
return
81+
}
82+
83+
if err != nil {
84+
t.Fatalf("NormalizeOAPURL(%q) unexpected error: %v", tc.in, err)
85+
}
86+
if got != tc.want {
87+
t.Fatalf("NormalizeOAPURL(%q) = %q, want %q", tc.in, got, tc.want)
88+
}
89+
})
90+
}
91+
}
92+
5393
func TestParseTimezoneOffset(t *testing.T) {
5494
loc, ok := parseTimezoneOffset("+0830")
5595
if !ok {

internal/tools/mqe.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,8 @@ func getContextBool(ctx context.Context, key any) bool {
9191
// executeGraphQLWithContext executes a GraphQL query using URL and auth from context.
9292
func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) {
9393
rawURL := getContextString(ctx, contextkey.BaseURL{})
94-
rawURL = FinalizeURL(rawURL)
95-
96-
if err := validateURLScheme(rawURL); err != nil {
94+
normalizedURL, err := NormalizeOAPURL(rawURL)
95+
if err != nil {
9796
return nil, err
9897
}
9998

@@ -107,7 +106,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
107106
return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err)
108107
}
109108

110-
req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData))
109+
req, err := http.NewRequestWithContext(ctx, "POST", normalizedURL, bytes.NewBuffer(jsonData))
111110
if err != nil {
112111
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
113112
}

0 commit comments

Comments
 (0)