diff --git a/testdata/localstack-init.sh b/testdata/localstack-init.sh index ad88f32f..7dcd6a7d 100755 --- a/testdata/localstack-init.sh +++ b/testdata/localstack-init.sh @@ -40,4 +40,28 @@ awslocal cloudwatch put-metric-data \ --value 1000000 \ --unit Bytes +# CloudWatch Logs test data +echo "Creating CloudWatch Logs test data..." + +awslocal logs create-log-group \ + --log-group-name "test-application-logs" \ + --region us-east-1 + +awslocal logs create-log-stream \ + --log-group-name "test-application-logs" \ + --log-stream-name "test-stream-1" \ + --region us-east-1 + +TIMESTAMP=$(date +%s000) +awslocal logs put-log-events \ + --log-group-name "test-application-logs" \ + --log-stream-name "test-stream-1" \ + --region us-east-1 \ + --log-events \ + "[{\"timestamp\":${TIMESTAMP},\"message\":\"ERROR: Connection timeout in service handler\"}, \ + {\"timestamp\":$((TIMESTAMP+1000)),\"message\":\"INFO: Request processed successfully\"}, \ + {\"timestamp\":$((TIMESTAMP+2000)),\"message\":\"WARN: High memory usage detected\"}, \ + {\"timestamp\":$((TIMESTAMP+3000)),\"message\":\"ERROR: Database query failed\"}, \ + {\"timestamp\":$((TIMESTAMP+4000)),\"message\":\"INFO: Health check passed\"}]" + echo "CloudWatch test data seeded successfully" diff --git a/tests/cloudwatch_logs_test.py b/tests/cloudwatch_logs_test.py new file mode 100644 index 00000000..8142f5a1 --- /dev/null +++ b/tests/cloudwatch_logs_test.py @@ -0,0 +1,61 @@ +import pytest +from mcp import ClientSession + +from conftest import models +from utils import assert_mcp_eval, run_llm_tool_loop + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("model", models) +@pytest.mark.flaky(reruns=2) +async def test_cloudwatch_list_log_groups( + model: str, + mcp_client: ClientSession, + mcp_transport: str, +): + """Test that the LLM can list CloudWatch log groups.""" + prompt = "List all CloudWatch log groups available on the CloudWatch datasource in Grafana. Use the us-east-1 region." + final_content, tools_called, mcp_server = await run_llm_tool_loop( + model, mcp_client, mcp_transport, prompt + ) + + assert_mcp_eval( + prompt, + final_content, + tools_called, + mcp_server, + "Does the response contain CloudWatch log group names? " + "It should mention specific log groups like 'test-application-logs' " + "or similar log group patterns. ", + expected_tools="list_cloudwatch_log_groups", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.flaky(reruns=2) +async def test_cloudwatch_query_logs( + model: str, + mcp_client: ClientSession, + mcp_transport: str, +): + """Test that the LLM can query CloudWatch Logs Insights.""" + prompt = ( + "Query CloudWatch Logs Insights for ERROR messages in the 'test-application-logs' log group " + "over the last hour. Use the us-east-1 region." + ) + final_content, tools_called, mcp_server = await run_llm_tool_loop( + model, mcp_client, mcp_transport, prompt + ) + + assert_mcp_eval( + prompt, + final_content, + tools_called, + mcp_server, + "Does the response provide information about CloudWatch log data? " + "It should either show log entries or messages, mention that logs were retrieved, " + "or explain that no log data was found in the specified time range. " + "Generic error messages don't count.", + expected_tools="query_cloudwatch_logs", + ) diff --git a/tools/cloudwatch.go b/tools/cloudwatch.go index 8de95fc4..f884fb8b 100644 --- a/tools/cloudwatch.go +++ b/tools/cloudwatch.go @@ -48,14 +48,45 @@ type CloudWatchQueryResult struct { Hints []string `json:"hints,omitempty"` } +// cloudWatchFrameMeta represents metadata in a Grafana data frame schema, +// such as query execution status (e.g. "Complete", "Running"). +// The Custom field can be either a JSON object (Logs Insights) or a string (Metrics), +// so it uses a custom unmarshaler to handle both cases. +type cloudWatchFrameMeta struct { + Custom cloudWatchCustomMeta `json:"custom,omitempty"` +} + +// cloudWatchCustomMeta handles the polymorphic "custom" field in frame metadata. +// For Logs Insights responses it's an object like {"Status":"Complete"}. +// For Metrics responses it can be a plain string like "timeSeriesQuery". +type cloudWatchCustomMeta struct { + Status string `json:"Status"` +} + +func (m *cloudWatchCustomMeta) UnmarshalJSON(data []byte) error { + // Try object first: {"Status":"Complete"} + type plain cloudWatchCustomMeta + if err := json.Unmarshal(data, (*plain)(m)); err == nil { + return nil + } + // Fall back to string (metrics responses) — ignore it, no Status to extract + var s string + if err := json.Unmarshal(data, &s); err == nil { + return nil + } + // Unknown format — silently ignore to avoid breaking existing functionality + return nil +} + // cloudWatchQueryResponse represents the raw API response from Grafana's /api/ds/query type cloudWatchQueryResponse struct { Results map[string]struct { Status int `json:"status,omitempty"` Frames []struct { Schema struct { - Name string `json:"name,omitempty"` - RefID string `json:"refId,omitempty"` + Name string `json:"name,omitempty"` + RefID string `json:"refId,omitempty"` + Meta *cloudWatchFrameMeta `json:"meta,omitempty"` Fields []struct { Name string `json:"name"` Type string `json:"type"` @@ -95,16 +126,12 @@ func newCloudWatchClient(ctx context.Context, uid string) (*cloudWatchClient, er cfg := mcpgrafana.GrafanaConfigFromContext(ctx) baseURL := strings.TrimRight(cfg.URL, "/") - // Create custom transport with TLS configuration if available - var transport = http.DefaultTransport - if tlsConfig := cfg.TLSConfig; tlsConfig != nil { - var err error - transport, err = tlsConfig.HTTPTransport(transport.(*http.Transport)) - if err != nil { - return nil, fmt.Errorf("failed to create custom transport: %w", err) - } + // CloudWatch uses /api/ds/query and /api/datasources/uid/.../resources/ + // (not proxy paths), so no fallback transport is needed. + transport, err := mcpgrafana.BuildTransport(&cfg, nil) + if err != nil { + return nil, fmt.Errorf("failed to create transport: %w", err) } - transport = NewAuthRoundTripper(transport, cfg.AccessToken, cfg.IDToken, cfg.APIKey, cfg.BasicAuth) transport = mcpgrafana.NewOrgIDRoundTripper(transport, cfg.OrgID) @@ -118,7 +145,48 @@ func newCloudWatchClient(ctx context.Context, uid string) (*cloudWatchClient, er }, nil } -// query executes a CloudWatch query via Grafana's /api/ds/query endpoint +// postDsQuery sends a POST request to /api/ds/query and returns the parsed response. +// Shared by both metrics (query) and logs (startLogsQuery, getLogsQueryResults). +func (c *cloudWatchClient) postDsQuery(ctx context.Context, payload map[string]interface{}) (*cloudWatchQueryResponse, error) { + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshaling query payload: %w", err) + } + + reqURL := c.baseURL + "/api/ds/query" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(payloadBytes)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("executing request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("query returned status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var bytesLimit int64 = 1024 * 1024 * 10 // 10MB limit + body := io.LimitReader(resp.Body, bytesLimit) + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + + var queryResp cloudWatchQueryResponse + if err := unmarshalJSONWithLimitMsg(bodyBytes, &queryResp, int(bytesLimit)); err != nil { + return nil, err + } + + return &queryResp, nil +} + +// query executes a CloudWatch metrics query via Grafana's /api/ds/query endpoint func (c *cloudWatchClient) query(ctx context.Context, args CloudWatchQueryParams, from, to time.Time) (*cloudWatchQueryResponse, error) { // Format dimensions for CloudWatch query // CloudWatch expects dimensions as map[string][]string @@ -170,17 +238,22 @@ func (c *cloudWatchClient) query(ctx context.Context, args CloudWatchQueryParams "to": strconv.FormatInt(to.UnixMilli(), 10), } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("marshaling query payload: %w", err) + return c.postDsQuery(ctx, payload) +} + +// fetchCloudWatchResource performs a GET request against the datasource resource API +// and returns the raw response body. The resourcePath is the suffix after +// /api/datasources/uid/{uid}/resources/ (e.g. "namespaces", "log-groups"). +func (c *cloudWatchClient) fetchCloudWatchResource(ctx context.Context, dsUID, resourcePath string, params url.Values) ([]byte, error) { + resourceURL := c.baseURL + "/api/datasources/uid/" + dsUID + "/resources/" + resourcePath + if len(params) > 0 { + resourceURL += "?" + params.Encode() } - url := c.baseURL + "/api/ds/query" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payloadBytes)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } - req.Header.Set("Content-Type", "application/json") resp, err := c.httpClient.Do(req) if err != nil { @@ -190,23 +263,12 @@ func (c *cloudWatchClient) query(ctx context.Context, args CloudWatchQueryParams if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("CloudWatch query returned status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - // Limit size of response read - var bytesLimit int64 = 1024 * 1024 * 10 // 10MB limit - body := io.LimitReader(resp.Body, bytesLimit) - bodyBytes, err := io.ReadAll(body) - if err != nil { - return nil, fmt.Errorf("reading response body: %w", err) + return nil, fmt.Errorf("CloudWatch %s returned status %d: %s", resourcePath, resp.StatusCode, string(bodyBytes)) } - var queryResp cloudWatchQueryResponse - if err := unmarshalJSONWithLimitMsg(bodyBytes, &queryResp, int(bytesLimit)); err != nil { - return nil, err - } - - return &queryResp, nil + bytesLimit := 1024 * 1024 // 1MB limit + body := io.LimitReader(resp.Body, int64(bytesLimit)) + return io.ReadAll(body) } // queryCloudWatch executes a CloudWatch query via Grafana @@ -448,7 +510,6 @@ func listCloudWatchNamespaces(ctx context.Context, args ListCloudWatchNamespaces return nil, fmt.Errorf("creating CloudWatch client: %w", err) } - // Build query parameters params := url.Values{} if args.Region != "" { params.Set("region", args.Region) @@ -457,34 +518,11 @@ func listCloudWatchNamespaces(ctx context.Context, args ListCloudWatchNamespaces params.Set("accountId", args.AccountId) } - resourceURL := client.baseURL + "/api/datasources/uid/" + args.DatasourceUID + "/resources/namespaces" - if len(params) > 0 { - resourceURL += "?" + params.Encode() - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + body, err := client.fetchCloudWatchResource(ctx, args.DatasourceUID, "namespaces", params) if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - resp, err := client.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("executing request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("CloudWatch namespaces returned status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - bytesLimit := 1024 * 1024 // 1MB limit - body := io.LimitReader(resp.Body, int64(bytesLimit)) - bodyBytes, err := io.ReadAll(body) - if err != nil { - return nil, fmt.Errorf("reading response body: %w", err) + return nil, err } - - return parseCloudWatchResourceResponse(bodyBytes, bytesLimit) + return parseCloudWatchResourceResponse(body, 1024*1024) } // ListCloudWatchNamespaces is a tool for listing CloudWatch namespaces @@ -512,7 +550,6 @@ func listCloudWatchMetrics(ctx context.Context, args ListCloudWatchMetricsParams return nil, fmt.Errorf("creating CloudWatch client: %w", err) } - // Build query parameters params := url.Values{} params.Set("namespace", args.Namespace) if args.Region != "" { @@ -522,31 +559,11 @@ func listCloudWatchMetrics(ctx context.Context, args ListCloudWatchMetricsParams params.Set("accountId", args.AccountId) } - resourceURL := client.baseURL + "/api/datasources/uid/" + args.DatasourceUID + "/resources/metrics?" + params.Encode() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - resp, err := client.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("executing request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("CloudWatch metrics returned status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - bytesLimit := 1024 * 1024 // 1MB limit - body := io.LimitReader(resp.Body, int64(bytesLimit)) - bodyBytes, err := io.ReadAll(body) + body, err := client.fetchCloudWatchResource(ctx, args.DatasourceUID, "metrics", params) if err != nil { - return nil, fmt.Errorf("reading response body: %w", err) + return nil, err } - - return parseCloudWatchMetricsResponse(bodyBytes, bytesLimit) + return parseCloudWatchMetricsResponse(body, 1024*1024) } // ListCloudWatchMetrics is a tool for listing CloudWatch metrics @@ -575,7 +592,6 @@ func listCloudWatchDimensions(ctx context.Context, args ListCloudWatchDimensions return nil, fmt.Errorf("creating CloudWatch client: %w", err) } - // Build query parameters params := url.Values{} params.Set("namespace", args.Namespace) params.Set("metricName", args.MetricName) @@ -586,31 +602,11 @@ func listCloudWatchDimensions(ctx context.Context, args ListCloudWatchDimensions params.Set("accountId", args.AccountId) } - resourceURL := client.baseURL + "/api/datasources/uid/" + args.DatasourceUID + "/resources/dimension-keys?" + params.Encode() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + body, err := client.fetchCloudWatchResource(ctx, args.DatasourceUID, "dimension-keys", params) if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - resp, err := client.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("executing request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("CloudWatch dimensions returned status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - bytesLimit := 1024 * 1024 // 1MB Limit - body := io.LimitReader(resp.Body, int64(bytesLimit)) - bodyBytes, err := io.ReadAll(body) - if err != nil { - return nil, fmt.Errorf("reading response body: %w", err) + return nil, err } - - return parseCloudWatchResourceResponse(bodyBytes, bytesLimit) + return parseCloudWatchResourceResponse(body, 1024*1024) } // ListCloudWatchDimensions is a tool for listing CloudWatch dimension keys @@ -625,8 +621,13 @@ var ListCloudWatchDimensions = mcpgrafana.MustTool( // AddCloudWatchTools registers all CloudWatch tools with the MCP server func AddCloudWatchTools(mcp *server.MCPServer) { + // Metrics tools QueryCloudWatch.Register(mcp) ListCloudWatchNamespaces.Register(mcp) ListCloudWatchMetrics.Register(mcp) ListCloudWatchDimensions.Register(mcp) + // Logs tools + QueryCloudWatchLogs.Register(mcp) + ListCloudWatchLogGroups.Register(mcp) + ListCloudWatchLogGroupFields.Register(mcp) } diff --git a/tools/cloudwatch_logs.go b/tools/cloudwatch_logs.go new file mode 100644 index 00000000..61c4bbd8 --- /dev/null +++ b/tools/cloudwatch_logs.go @@ -0,0 +1,523 @@ +package tools + +import ( + "context" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + mcpgrafana "github.com/grafana/mcp-grafana" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + // DefaultCloudWatchLogsLimit is the default number of log entries to return + DefaultCloudWatchLogsLimit = 100 + + // MaxCloudWatchLogsLimit is the maximum number of log entries that can be requested + MaxCloudWatchLogsLimit = 1000 + + // defaultLogsQueryTimeout is the maximum time to wait for a Logs Insights query to complete + defaultLogsQueryTimeout = 30 * time.Second + + // initialPollInterval is the starting interval for polling GetQueryResults + initialPollInterval = 200 * time.Millisecond + + // maxPollInterval is the maximum interval between polling attempts + maxPollInterval = 2 * time.Second + + // pollBackoffMultiplier is the multiplier for exponential backoff + pollBackoffMultiplier = 1.5 +) + +// QueryCloudWatchLogsParams defines the parameters for a CloudWatch Logs Insights query +type QueryCloudWatchLogsParams struct { + DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the CloudWatch datasource. Use list_datasources to find available UIDs."` + Region string `json:"region" jsonschema:"required,description=AWS region (e.g. us-east-1)"` + LogGroupNames []string `json:"logGroupNames" jsonschema:"required,description=List of log group names to query (e.g. [\"cloudwatch-prod\"\\, \"/aws/lambda/my-function\"]). Use list_cloudwatch_log_groups to discover available groups."` + QueryString string `json:"queryString" jsonschema:"required,description=CloudWatch Logs Insights query string. Example: 'fields @timestamp\\, @message | filter @message like /ERROR/ | sort @timestamp desc | limit 20'"` + Start string `json:"start,omitempty" jsonschema:"description=Start time. Formats: 'now-1h'\\, '2026-02-02T19:00:00Z'\\, '1738519200000' (Unix ms). Default: now-1h"` + End string `json:"end,omitempty" jsonschema:"description=End time. Formats: 'now'\\, '2026-02-02T20:00:00Z'\\, '1738522800000' (Unix ms). Default: now"` + Limit int `json:"limit,omitempty" jsonschema:"description=Maximum number of log entries to return (default: 100\\, max: 1000). Note: this is applied to the result; include a 'limit' clause in your query for server-side limiting."` +} + +// CloudWatchLogsQueryResult represents the result of a CloudWatch Logs Insights query +type CloudWatchLogsQueryResult struct { + Logs []CloudWatchLogEntry `json:"logs"` + Query string `json:"query"` + TotalFound int `json:"totalFound"` + Status string `json:"status"` + Hints []string `json:"hints,omitempty"` +} + +// CloudWatchLogEntry represents a single log entry returned by Logs Insights. +// Fields are dynamic based on the query (e.g. @timestamp, @message, custom fields). +type CloudWatchLogEntry struct { + Fields map[string]string `json:"fields"` +} + +// ListCloudWatchLogGroupsParams defines the parameters for listing CloudWatch log groups +type ListCloudWatchLogGroupsParams struct { + DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the CloudWatch datasource"` + Region string `json:"region" jsonschema:"required,description=AWS region (e.g. us-east-1)"` + Pattern string `json:"pattern,omitempty" jsonschema:"description=Optional pattern to filter log group names (prefix match)"` + AccountId string `json:"accountId,omitempty" jsonschema:"description=AWS account ID for cross-account monitoring."` +} + +// ListCloudWatchLogGroupFieldsParams defines the parameters for listing fields in a log group +type ListCloudWatchLogGroupFieldsParams struct { + DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the CloudWatch datasource"` + Region string `json:"region" jsonschema:"required,description=AWS region (e.g. us-east-1)"` + LogGroupName string `json:"logGroupName" jsonschema:"required,description=The log group name to discover fields for"` + AccountId string `json:"accountId,omitempty" jsonschema:"description=AWS account ID for cross-account monitoring."` +} + +// cloudWatchLogGroupItem represents a log group returned by the log-groups resource API. +// Response format: [{"value": {"arn": "...", "name": "..."}, "accountId": "..."}] +type cloudWatchLogGroupItem struct { + Value struct { + ARN string `json:"arn"` + Name string `json:"name"` + } `json:"value"` + AccountId string `json:"accountId,omitempty"` +} + +// cloudWatchLogGroupFieldItem represents a field returned by the log-group-fields resource API. +// Response format: [{"value": {"name": "...", "percent": 50}, "accountId": "..."}] +type cloudWatchLogGroupFieldItem struct { + Value struct { + Name string `json:"name"` + Percent int64 `json:"percent"` + } `json:"value"` +} + +// parseCloudWatchLogGroupsResponse extracts log group names from the resource API response +func parseCloudWatchLogGroupsResponse(bodyBytes []byte, bytesLimit int) ([]string, error) { + var items []cloudWatchLogGroupItem + if err := unmarshalJSONWithLimitMsg(bodyBytes, &items, bytesLimit); err != nil { + return nil, err + } + + result := make([]string, len(items)) + for i, item := range items { + result[i] = item.Value.Name + } + return result, nil +} + +// parseCloudWatchLogGroupFieldsResponse extracts field names from the resource API response +func parseCloudWatchLogGroupFieldsResponse(bodyBytes []byte, bytesLimit int) ([]string, error) { + var items []cloudWatchLogGroupFieldItem + if err := unmarshalJSONWithLimitMsg(bodyBytes, &items, bytesLimit); err != nil { + return nil, err + } + + result := make([]string, len(items)) + for i, item := range items { + result[i] = item.Value.Name + } + return result, nil +} + +// enforceCloudWatchLogsLimit ensures a limit value is within acceptable bounds +func enforceCloudWatchLogsLimit(requestedLimit int) int { + if requestedLimit <= 0 { + return DefaultCloudWatchLogsLimit + } + if requestedLimit > MaxCloudWatchLogsLimit { + return MaxCloudWatchLogsLimit + } + return requestedLimit +} + +// listCloudWatchLogGroups lists available CloudWatch log groups via the resource API +func listCloudWatchLogGroups(ctx context.Context, args ListCloudWatchLogGroupsParams) ([]string, error) { + client, err := newCloudWatchClient(ctx, args.DatasourceUID) + if err != nil { + return nil, fmt.Errorf("creating CloudWatch client: %w", err) + } + + params := url.Values{} + if args.Region != "" { + params.Set("region", args.Region) + } + if args.Pattern != "" { + params.Set("logGroupNamePrefix", args.Pattern) + } + if args.AccountId != "" { + params.Set("accountId", args.AccountId) + } + + body, err := client.fetchCloudWatchResource(ctx, args.DatasourceUID, "log-groups", params) + if err != nil { + return nil, err + } + return parseCloudWatchLogGroupsResponse(body, 1024*1024) +} + +// listCloudWatchLogGroupFields lists discovered fields for a CloudWatch log group +func listCloudWatchLogGroupFields(ctx context.Context, args ListCloudWatchLogGroupFieldsParams) ([]string, error) { + client, err := newCloudWatchClient(ctx, args.DatasourceUID) + if err != nil { + return nil, fmt.Errorf("creating CloudWatch client: %w", err) + } + + params := url.Values{} + params.Set("logGroupName", args.LogGroupName) + if args.Region != "" { + params.Set("region", args.Region) + } + if args.AccountId != "" { + params.Set("accountId", args.AccountId) + } + + body, err := client.fetchCloudWatchResource(ctx, args.DatasourceUID, "log-group-fields", params) + if err != nil { + return nil, err + } + return parseCloudWatchLogGroupFieldsResponse(body, 1024*1024) +} + +// queryCloudWatchLogs executes a CloudWatch Logs Insights query via Grafana. +// It handles the async StartQuery → poll GetQueryResults flow internally. +func queryCloudWatchLogs(ctx context.Context, args QueryCloudWatchLogsParams) (*CloudWatchLogsQueryResult, error) { + client, err := newCloudWatchClient(ctx, args.DatasourceUID) + if err != nil { + return nil, fmt.Errorf("creating CloudWatch client: %w", err) + } + + // Parse time range + now := time.Now() + fromTime := now.Add(-1 * time.Hour) // Default: 1 hour ago + toTime := now // Default: now + + if args.Start != "" { + parsed, err := parseStartTime(args.Start) + if err != nil { + return nil, fmt.Errorf("parsing start time: %w", err) + } + if !parsed.IsZero() { + fromTime = parsed + } + } + + if args.End != "" { + parsed, err := parseEndTime(args.End) + if err != nil { + return nil, fmt.Errorf("parsing end time: %w", err) + } + if !parsed.IsZero() { + toTime = parsed + } + } + + // Step 1: Start the query + queryID, err := client.startLogsQuery(ctx, args, fromTime, toTime) + if err != nil { + return nil, fmt.Errorf("starting CloudWatch Logs query: %w", err) + } + + // Step 2: Poll for results + resp, err := client.pollLogsQueryResults(ctx, args.DatasourceUID, queryID, args.Region, fromTime, toTime, defaultLogsQueryTimeout) + if err != nil { + return nil, err + } + + // Step 3: Parse results + limit := enforceCloudWatchLogsLimit(args.Limit) + result, err := parseLogsQueryResponse(resp, args.QueryString, limit) + if err != nil { + return nil, err + } + + return result, nil +} + +// startLogsQuery sends the StartQuery request and extracts the queryId from the response +func (c *cloudWatchClient) startLogsQuery(ctx context.Context, args QueryCloudWatchLogsParams, from, to time.Time) (string, error) { + query := map[string]interface{}{ + "datasource": map[string]string{ + "uid": args.DatasourceUID, + "type": CloudWatchDatasourceType, + }, + "refId": "A", + "type": "logAction", + "subtype": "StartQuery", + "queryMode": "Logs", + "region": args.Region, + "queryString": args.QueryString, + "logGroupNames": args.LogGroupNames, + "id": "", + "intervalMs": 1, + "maxDataPoints": 1, + } + + payload := map[string]interface{}{ + "queries": []map[string]interface{}{query}, + "from": strconv.FormatInt(from.UnixMilli(), 10), + "to": strconv.FormatInt(to.UnixMilli(), 10), + } + + resp, err := c.postDsQuery(ctx, payload) + if err != nil { + return "", err + } + + // Extract queryId from the response + queryID, err := extractQueryID(resp) + if err != nil { + return "", fmt.Errorf("extracting queryId from StartQuery response: %w", err) + } + + return queryID, nil +} + +// pollLogsQueryResults polls GetQueryResults with exponential backoff until complete or timeout +func (c *cloudWatchClient) pollLogsQueryResults(ctx context.Context, dsUID, queryID, region string, from, to time.Time, timeout time.Duration) (*cloudWatchQueryResponse, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + interval := initialPollInterval + + for { + resp, err := c.getLogsQueryResults(ctx, dsUID, queryID, region, from, to) + if err != nil { + return nil, fmt.Errorf("polling CloudWatch Logs query results: %w", err) + } + + status := extractQueryStatus(resp) + switch status { + case "Complete": + return resp, nil + case "Running", "Scheduled", "": + // Continue polling + default: + return nil, fmt.Errorf("CloudWatch Logs query failed with status: %s", status) + } + + // Backoff with context/timeout cancellation support + timer := time.NewTimer(interval) + select { + case <-ctx.Done(): + timer.Stop() + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("CloudWatch Logs query timed out after %s (last status: %s)", timeout, status) + } + return nil, ctx.Err() + case <-timer.C: + } + + interval = time.Duration(float64(interval) * pollBackoffMultiplier) + if interval > maxPollInterval { + interval = maxPollInterval + } + } +} + +// getLogsQueryResults sends a GetQueryResults request for a given queryId +func (c *cloudWatchClient) getLogsQueryResults(ctx context.Context, dsUID, queryID, region string, from, to time.Time) (*cloudWatchQueryResponse, error) { + query := map[string]interface{}{ + "datasource": map[string]string{ + "uid": dsUID, + "type": CloudWatchDatasourceType, + }, + "refId": "A", + "type": "logAction", + "subtype": "GetQueryResults", + "queryMode": "Logs", + "region": region, + "queryId": queryID, + "id": "", + "intervalMs": 1, + } + + payload := map[string]interface{}{ + "queries": []map[string]interface{}{query}, + "from": strconv.FormatInt(from.UnixMilli(), 10), + "to": strconv.FormatInt(to.UnixMilli(), 10), + } + + return c.postDsQuery(ctx, payload) +} + +// extractQueryID extracts the queryId from a StartQuery response. +// The Grafana CloudWatch plugin returns the queryId in the first frame's data. +func extractQueryID(resp *cloudWatchQueryResponse) (string, error) { + for _, r := range resp.Results { + if r.Error != "" { + return "", fmt.Errorf("query error: %s", r.Error) + } + + for _, frame := range r.Frames { + // Look for a field named "queryId" in the schema + for i, field := range frame.Schema.Fields { + if field.Name == "queryId" && i < len(frame.Data.Values) { + if len(frame.Data.Values[i]) > 0 { + if qid, ok := frame.Data.Values[i][0].(string); ok && qid != "" { + return qid, nil + } + } + } + } + + } + } + + return "", fmt.Errorf("no queryId found in StartQuery response") +} + +// extractQueryStatus extracts the query status from the response frame metadata +func extractQueryStatus(resp *cloudWatchQueryResponse) string { + for _, r := range resp.Results { + for _, frame := range r.Frames { + if frame.Schema.Meta != nil { + return frame.Schema.Meta.Custom.Status + } + } + } + return "" +} + +// parseLogsQueryResponse converts the raw Grafana data frame response to CloudWatchLogsQueryResult. +// The response contains columnar data: schema.fields defines column names, data.values contains parallel arrays. +func parseLogsQueryResponse(resp *cloudWatchQueryResponse, query string, limit int) (*CloudWatchLogsQueryResult, error) { + result := &CloudWatchLogsQueryResult{ + Query: query, + Logs: []CloudWatchLogEntry{}, + } + + for refID, r := range resp.Results { + if r.Error != "" { + return nil, fmt.Errorf("query error (refId=%s): %s", refID, r.Error) + } + + // Extract status from first frame's meta + if len(r.Frames) > 0 && r.Frames[0].Schema.Meta != nil { + result.Status = r.Frames[0].Schema.Meta.Custom.Status + } + + for _, frame := range r.Frames { + fieldNames := make([]string, len(frame.Schema.Fields)) + for i, f := range frame.Schema.Fields { + fieldNames[i] = f.Name + } + + if len(frame.Data.Values) == 0 || len(fieldNames) == 0 { + continue + } + + // Determine row count from first column + rowCount := len(frame.Data.Values[0]) + + for row := 0; row < rowCount; row++ { + if len(result.Logs) >= limit { + break + } + + entry := CloudWatchLogEntry{ + Fields: make(map[string]string), + } + for col := 0; col < len(fieldNames) && col < len(frame.Data.Values); col++ { + if row < len(frame.Data.Values[col]) { + val := frame.Data.Values[col][row] + // Skip Grafana-internal metadata fields + if fieldNames[col] == "@ptr" || strings.HasSuffix(fieldNames[col], "__grafana_internal__") { + continue + } + entry.Fields[fieldNames[col]] = formatLogValue(val, frame.Schema.Fields[col].Type) + } + } + result.Logs = append(result.Logs, entry) + } + } + } + + result.TotalFound = len(result.Logs) + + if len(result.Logs) == 0 { + result.Hints = generateCloudWatchLogsEmptyResultHints() + } + + return result, nil +} + +// formatLogValue converts a raw interface{} value to a string for display. +// It handles timestamps (float64 ms → RFC3339), strings, and nil values. +func formatLogValue(v interface{}, fieldType string) string { + switch val := v.(type) { + case string: + return val + case float64: + if fieldType == "time" { + return time.UnixMilli(int64(val)).UTC().Format(time.RFC3339Nano) + } + return strconv.FormatFloat(val, 'f', -1, 64) + case int64: + if fieldType == "time" { + return time.UnixMilli(val).UTC().Format(time.RFC3339Nano) + } + return strconv.FormatInt(val, 10) + case nil: + return "" + default: + return fmt.Sprintf("%v", val) + } +} + +// generateCloudWatchLogsEmptyResultHints generates helpful hints when a Logs Insights query returns no data +func generateCloudWatchLogsEmptyResultHints() []string { + return []string{ + "No log data found. Possible reasons:", + "- Log group name may be incorrect - use list_cloudwatch_log_groups to discover available groups", + "- Query syntax may be invalid - check CloudWatch Logs Insights query syntax", + "- Filter may be too restrictive - try a broader filter or remove it", + "- Time range may have no log events - try extending with start=\"now-6h\"", + "- Region may be incorrect - verify the log group exists in the specified region", + "- Use list_cloudwatch_log_group_fields to discover available fields for the log group", + } +} + +// Tool definitions + +// QueryCloudWatchLogs is a tool for executing CloudWatch Logs Insights queries via Grafana +var QueryCloudWatchLogs = mcpgrafana.MustTool( + "query_cloudwatch_logs", + `Execute a CloudWatch Logs Insights query via Grafana. Requires region and at least one log group. + +REQUIRED FIRST: Use list_cloudwatch_log_groups -> list_cloudwatch_log_group_fields -> then query. + +The query uses CloudWatch Logs Insights syntax: +- fields @timestamp, @message | sort @timestamp desc | limit 20 +- filter @message like /error/i | stats count() by bin(5m) +- fields @timestamp, @message, @logStream | filter @message like /exception/ + +Time formats: 'now-1h', '2026-02-02T19:00:00Z', '1738519200000' (Unix ms) + +Cross-account monitoring: Use the accountId parameter in list tools to discover log groups from linked accounts.`, + queryCloudWatchLogs, + mcp.WithTitleAnnotation("Query CloudWatch Logs"), + mcp.WithIdempotentHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), +) + +// ListCloudWatchLogGroups is a tool for listing CloudWatch log groups +var ListCloudWatchLogGroups = mcpgrafana.MustTool( + "list_cloudwatch_log_groups", + "START HERE for CloudWatch Logs: List available log groups. Requires region. Supports filtering by prefix pattern and cross-account monitoring via optional accountId. NEXT: Use list_cloudwatch_log_group_fields, then query_cloudwatch_logs.", + listCloudWatchLogGroups, + mcp.WithTitleAnnotation("List CloudWatch log groups"), + mcp.WithIdempotentHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), +) + +// ListCloudWatchLogGroupFields is a tool for listing fields in a CloudWatch log group +var ListCloudWatchLogGroupFields = mcpgrafana.MustTool( + "list_cloudwatch_log_group_fields", + "List discovered fields for a CloudWatch log group. Use after list_cloudwatch_log_groups to find available fields for querying. Requires region. NEXT: Use query_cloudwatch_logs with the discovered fields.", + listCloudWatchLogGroupFields, + mcp.WithTitleAnnotation("List CloudWatch log group fields"), + mcp.WithIdempotentHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), +) diff --git a/tools/cloudwatch_logs_integration_test.go b/tools/cloudwatch_logs_integration_test.go new file mode 100644 index 00000000..b9828429 --- /dev/null +++ b/tools/cloudwatch_logs_integration_test.go @@ -0,0 +1,97 @@ +//go:build integration + +package tools + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCloudWatchLogsIntegration_ListLogGroups(t *testing.T) { + ctx := newTestContext() + + result, err := listCloudWatchLogGroups(ctx, ListCloudWatchLogGroupsParams{ + DatasourceUID: cloudwatchTestDatasourceUID, + Region: "us-east-1", + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.GreaterOrEqual(t, len(result), 1, "Should find at least one log group") +} + +func TestCloudWatchLogsIntegration_ListLogGroupFields(t *testing.T) { + ctx := newTestContext() + + result, err := listCloudWatchLogGroupFields(ctx, ListCloudWatchLogGroupFieldsParams{ + DatasourceUID: cloudwatchTestDatasourceUID, + Region: "us-east-1", + LogGroupName: "test-application-logs", + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.GreaterOrEqual(t, len(result), 0) +} + +func TestCloudWatchLogsIntegration_QueryLogs(t *testing.T) { + ctx := newTestContext() + + result, err := queryCloudWatchLogs(ctx, QueryCloudWatchLogsParams{ + DatasourceUID: cloudwatchTestDatasourceUID, + Region: "us-east-1", + LogGroupNames: []string{"test-application-logs"}, + QueryString: "fields @timestamp, @message | sort @timestamp desc | limit 5", + Start: "now-1h", + End: "now", + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.NotNil(t, result.Logs) +} + +func TestCloudWatchLogsIntegration_QueryEmptyResult(t *testing.T) { + ctx := newTestContext() + + result, err := queryCloudWatchLogs(ctx, QueryCloudWatchLogsParams{ + DatasourceUID: cloudwatchTestDatasourceUID, + Region: "us-east-1", + LogGroupNames: []string{"nonexistent-log-group"}, + QueryString: "fields @timestamp, @message | limit 5", + Start: "now-1h", + End: "now", + }) + + if err == nil { + require.NotNil(t, result) + if len(result.Logs) == 0 { + assert.NotEmpty(t, result.Hints, "Empty result should have hints") + } + } +} + +func TestCloudWatchLogsIntegration_InvalidDatasource(t *testing.T) { + ctx := newTestContext() + + _, err := listCloudWatchLogGroups(ctx, ListCloudWatchLogGroupsParams{ + DatasourceUID: "nonexistent-uid", + Region: "us-east-1", + }) + + require.Error(t, err) +} + +func TestCloudWatchLogsIntegration_WrongDatasourceType(t *testing.T) { + ctx := newTestContext() + + _, err := listCloudWatchLogGroups(ctx, ListCloudWatchLogGroupsParams{ + DatasourceUID: "prometheus", + Region: "us-east-1", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "not cloudwatch") +} diff --git a/tools/cloudwatch_logs_test.go b/tools/cloudwatch_logs_test.go new file mode 100644 index 00000000..d0852ac3 --- /dev/null +++ b/tools/cloudwatch_logs_test.go @@ -0,0 +1,445 @@ +//go:build unit + +package tools + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnforceCloudWatchLogsLimit(t *testing.T) { + tests := []struct { + name string + input int + expected int + }{ + {name: "zero returns default", input: 0, expected: DefaultCloudWatchLogsLimit}, + {name: "negative returns default", input: -1, expected: DefaultCloudWatchLogsLimit}, + {name: "within range", input: 50, expected: 50}, + {name: "exactly default", input: 100, expected: 100}, + {name: "exactly max", input: MaxCloudWatchLogsLimit, expected: MaxCloudWatchLogsLimit}, + {name: "exceeds max", input: 5000, expected: MaxCloudWatchLogsLimit}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, enforceCloudWatchLogsLimit(tt.input)) + }) + } +} + +func TestParseCloudWatchLogGroupsResponse(t *testing.T) { + tests := []struct { + name string + input string + expected []string + expectError bool + }{ + { + name: "valid single log group", + input: `[{"value":{"arn":"arn:aws:logs:us-east-1:123:log-group:/ecs/core-prod","name":"/ecs/core-prod"},"accountId":"123"}]`, + expected: []string{"/ecs/core-prod"}, + }, + { + name: "multiple log groups", + input: `[{"value":{"arn":"arn:aws:logs:us-east-1:123:log-group:/ecs/prod","name":"/ecs/prod"}},{"value":{"arn":"arn:aws:logs:us-east-1:123:log-group:/ecs/staging","name":"/ecs/staging"}}]`, + expected: []string{"/ecs/prod", "/ecs/staging"}, + }, + { + name: "empty array", + input: `[]`, + expected: []string{}, + }, + { + name: "invalid JSON", + input: `not json`, + expectError: true, + }, + { + name: "no accountId field", + input: `[{"value":{"arn":"some-arn","name":"/app/logs"}}]`, + expected: []string{"/app/logs"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseCloudWatchLogGroupsResponse([]byte(tt.input), 1024*1024) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseCloudWatchLogGroupFieldsResponse(t *testing.T) { + tests := []struct { + name string + input string + expected []string + expectError bool + }{ + { + name: "valid fields", + input: `[{"value":{"name":"@timestamp","percent":100}},{"value":{"name":"@message","percent":95}}]`, + expected: []string{"@timestamp", "@message"}, + }, + { + name: "single field", + input: `[{"value":{"name":"level","percent":50}}]`, + expected: []string{"level"}, + }, + { + name: "empty array", + input: `[]`, + expected: []string{}, + }, + { + name: "invalid JSON", + input: `{bad`, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseCloudWatchLogGroupFieldsResponse([]byte(tt.input), 1024*1024) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFormatLogValue(t *testing.T) { + tests := []struct { + name string + value interface{} + fieldType string + expected string + }{ + {name: "string value", value: "hello world", fieldType: "string", expected: "hello world"}, + {name: "float64 number", value: float64(25.5), fieldType: "number", expected: "25.5"}, + {name: "float64 integer", value: float64(42), fieldType: "number", expected: "42"}, + {name: "float64 timestamp", value: float64(1705312800000), fieldType: "time", + expected: time.UnixMilli(1705312800000).UTC().Format(time.RFC3339Nano)}, + {name: "int64 number", value: int64(42), fieldType: "number", expected: "42"}, + {name: "int64 timestamp", value: int64(1705312800000), fieldType: "time", + expected: time.UnixMilli(1705312800000).UTC().Format(time.RFC3339Nano)}, + {name: "nil value", value: nil, fieldType: "string", expected: ""}, + {name: "bool fallback", value: true, fieldType: "boolean", expected: "true"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, formatLogValue(tt.value, tt.fieldType)) + }) + } +} + +func TestExtractQueryID(t *testing.T) { + t.Run("valid queryId field", func(t *testing.T) { + resp := buildLogsResponse(t, "", + []fieldDef{{Name: "queryId", Type: "string"}}, + [][]interface{}{{"abc-123-query-id"}}, + "", + ) + qid, err := extractQueryID(resp) + require.NoError(t, err) + assert.Equal(t, "abc-123-query-id", qid) + }) + + t.Run("no frames", func(t *testing.T) { + resp := buildLogsResponse(t, "", nil, nil, "") + _, err := extractQueryID(resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "no queryId found") + }) + + t.Run("error in response", func(t *testing.T) { + resp := buildLogsResponse(t, "", nil, nil, "something went wrong") + _, err := extractQueryID(resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "something went wrong") + }) +} + +func TestExtractQueryStatus(t *testing.T) { + t.Run("complete status", func(t *testing.T) { + resp := buildLogsResponse(t, "Complete", nil, nil, "") + assert.Equal(t, "Complete", extractQueryStatus(resp)) + }) + + t.Run("running status", func(t *testing.T) { + resp := buildLogsResponse(t, "Running", nil, nil, "") + assert.Equal(t, "Running", extractQueryStatus(resp)) + }) + + t.Run("no meta", func(t *testing.T) { + resp := buildLogsResponse(t, "", nil, nil, "") + assert.Equal(t, "", extractQueryStatus(resp)) + }) + + t.Run("no frames", func(t *testing.T) { + resp := &cloudWatchQueryResponse{} + data := `{"results":{"A":{"frames":[]}}}` + _ = json.Unmarshal([]byte(data), resp) + assert.Equal(t, "", extractQueryStatus(resp)) + }) +} + +func TestParseLogsQueryResponse(t *testing.T) { + t.Run("complete response with multiple fields", func(t *testing.T) { + fields := []fieldDef{ + {Name: "@timestamp", Type: "time"}, + {Name: "@message", Type: "string"}, + {Name: "level", Type: "string"}, + } + values := [][]interface{}{ + {float64(1705312800000), float64(1705312900000)}, + {"ERROR: something failed", "INFO: all good"}, + {"error", "info"}, + } + resp := buildLogsResponse(t, "Complete", fields, values, "") + + result, err := parseLogsQueryResponse(resp, "fields @timestamp, @message, level", 100) + require.NoError(t, err) + assert.Equal(t, "Complete", result.Status) + assert.Equal(t, 2, result.TotalFound) + assert.Len(t, result.Logs, 2) + + assert.Equal(t, "ERROR: something failed", result.Logs[0].Fields["@message"]) + assert.Equal(t, "error", result.Logs[0].Fields["level"]) + assert.Contains(t, result.Logs[0].Fields["@timestamp"], "2024-01-15") + }) + + t.Run("strips @ptr field", func(t *testing.T) { + fields := []fieldDef{ + {Name: "@timestamp", Type: "time"}, + {Name: "@message", Type: "string"}, + {Name: "@ptr", Type: "string"}, + } + values := [][]interface{}{ + {float64(1705312800000)}, + {"test message"}, + {"some-internal-pointer"}, + } + resp := buildLogsResponse(t, "Complete", fields, values, "") + + result, err := parseLogsQueryResponse(resp, "test", 100) + require.NoError(t, err) + assert.Len(t, result.Logs, 1) + _, hasPtr := result.Logs[0].Fields["@ptr"] + assert.False(t, hasPtr, "@ptr should be stripped") + assert.Equal(t, "test message", result.Logs[0].Fields["@message"]) + }) + + t.Run("strips grafana internal fields", func(t *testing.T) { + fields := []fieldDef{ + {Name: "@timestamp", Type: "time"}, + {Name: "@message", Type: "string"}, + {Name: "__log__grafana_internal__", Type: "string"}, + {Name: "__logstream__grafana_internal__", Type: "string"}, + } + values := [][]interface{}{ + {float64(1705312800000)}, + {"test message"}, + {"442042515479:/ecs/core-prod"}, + {"api/apiContainer/abc123"}, + } + resp := buildLogsResponse(t, "Complete", fields, values, "") + + result, err := parseLogsQueryResponse(resp, "test", 100) + require.NoError(t, err) + assert.Len(t, result.Logs, 1) + _, hasLog := result.Logs[0].Fields["__log__grafana_internal__"] + assert.False(t, hasLog, "__log__grafana_internal__ should be stripped") + _, hasLogStream := result.Logs[0].Fields["__logstream__grafana_internal__"] + assert.False(t, hasLogStream, "__logstream__grafana_internal__ should be stripped") + }) + + t.Run("empty result generates hints", func(t *testing.T) { + fields := []fieldDef{ + {Name: "@timestamp", Type: "time"}, + {Name: "@message", Type: "string"}, + } + values := [][]interface{}{ + {}, + {}, + } + resp := buildLogsResponse(t, "Complete", fields, values, "") + + result, err := parseLogsQueryResponse(resp, "test", 100) + require.NoError(t, err) + assert.Equal(t, 0, result.TotalFound) + assert.NotEmpty(t, result.Hints) + }) + + t.Run("error in response", func(t *testing.T) { + resp := buildLogsResponse(t, "", nil, nil, "query syntax error") + + _, err := parseLogsQueryResponse(resp, "bad query", 100) + require.Error(t, err) + assert.Contains(t, err.Error(), "query syntax error") + }) + + t.Run("respects limit", func(t *testing.T) { + fields := []fieldDef{ + {Name: "@message", Type: "string"}, + } + values := [][]interface{}{ + {"msg1", "msg2", "msg3", "msg4", "msg5"}, + } + resp := buildLogsResponse(t, "Complete", fields, values, "") + + result, err := parseLogsQueryResponse(resp, "test", 3) + require.NoError(t, err) + assert.Len(t, result.Logs, 3) + assert.Equal(t, 3, result.TotalFound) + }) +} + +func TestGenerateCloudWatchLogsEmptyResultHints(t *testing.T) { + hints := generateCloudWatchLogsEmptyResultHints() + + assert.NotEmpty(t, hints) + assert.GreaterOrEqual(t, len(hints), 5) + assert.Contains(t, hints[0], "No log data found") + + hintsStr := strings.Join(hints, " ") + assert.Contains(t, hintsStr, "list_cloudwatch_log_groups") + assert.Contains(t, hintsStr, "list_cloudwatch_log_group_fields") +} + +func TestCloudWatchLogsQueryResult_Structure(t *testing.T) { + result := CloudWatchLogsQueryResult{ + Logs: []CloudWatchLogEntry{ + {Fields: map[string]string{"@timestamp": "2024-01-15T10:00:00Z", "@message": "ERROR: test"}}, + {Fields: map[string]string{"@timestamp": "2024-01-15T10:01:00Z", "@message": "INFO: ok"}}, + }, + Query: "fields @timestamp, @message", + TotalFound: 2, + Status: "Complete", + } + + assert.Len(t, result.Logs, 2) + assert.Equal(t, "ERROR: test", result.Logs[0].Fields["@message"]) + assert.Equal(t, 2, result.TotalFound) + assert.Equal(t, "Complete", result.Status) + assert.Nil(t, result.Hints) +} + +func TestCloudWatchLogsQueryResult_JSONSerialization(t *testing.T) { + t.Run("hints omitted when nil", func(t *testing.T) { + result := CloudWatchLogsQueryResult{ + Logs: []CloudWatchLogEntry{}, + Query: "test", + TotalFound: 0, + Status: "Complete", + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(data, &m)) + _, hasHints := m["hints"] + assert.False(t, hasHints, "hints should be omitted when nil") + }) + + t.Run("hints included when present", func(t *testing.T) { + result := CloudWatchLogsQueryResult{ + Logs: []CloudWatchLogEntry{}, + Query: "test", + TotalFound: 0, + Status: "Complete", + Hints: []string{"hint1"}, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(data, &m)) + _, hasHints := m["hints"] + assert.True(t, hasHints, "hints should be included when present") + }) +} + +// Helper types and functions for building test responses + +type fieldDef struct { + Name string + Type string +} + +// buildLogsResponse constructs a cloudWatchQueryResponse via JSON round-trip, +// avoiding inline anonymous struct matching issues. +func buildLogsResponse(t *testing.T, status string, fields []fieldDef, values [][]interface{}, errMsg string) *cloudWatchQueryResponse { + t.Helper() + + type frameMeta struct { + Custom struct { + Status string `json:"Status"` + } `json:"custom,omitempty"` + } + type schemaField struct { + Name string `json:"name"` + Type string `json:"type"` + } + type frame struct { + Schema struct { + Meta *frameMeta `json:"meta,omitempty"` + Fields []schemaField `json:"fields"` + } `json:"schema"` + Data struct { + Values [][]interface{} `json:"values"` + } `json:"data"` + } + type resultEntry struct { + Frames []frame `json:"frames,omitempty"` + Error string `json:"error,omitempty"` + } + + r := resultEntry{Error: errMsg} + + if fields != nil || status != "" { + f := frame{} + if status != "" { + f.Schema.Meta = &frameMeta{} + f.Schema.Meta.Custom.Status = status + } + if fields != nil { + f.Schema.Fields = make([]schemaField, len(fields)) + for i, fd := range fields { + f.Schema.Fields[i] = schemaField{Name: fd.Name, Type: fd.Type} + } + } + if values != nil { + f.Data.Values = values + } + r.Frames = []frame{f} + } + + wrapper := map[string]map[string]resultEntry{ + "results": {"A": r}, + } + + data, err := json.Marshal(wrapper) + require.NoError(t, err) + + var resp cloudWatchQueryResponse + require.NoError(t, json.Unmarshal(data, &resp)) + return &resp +} diff --git a/tools/cloudwatch_test.go b/tools/cloudwatch_test.go index acad6318..e270c877 100644 --- a/tools/cloudwatch_test.go +++ b/tools/cloudwatch_test.go @@ -256,8 +256,9 @@ func TestCloudWatchMultiFrameStatistics(t *testing.T) { Status int `json:"status,omitempty"` Frames []struct { Schema struct { - Name string `json:"name,omitempty"` - RefID string `json:"refId,omitempty"` + Name string `json:"name,omitempty"` + RefID string `json:"refId,omitempty"` + Meta *cloudWatchFrameMeta `json:"meta,omitempty"` Fields []struct { Name string `json:"name"` Type string `json:"type"` @@ -279,8 +280,9 @@ func TestCloudWatchMultiFrameStatistics(t *testing.T) { // Frame type for convenience type frame = struct { Schema struct { - Name string `json:"name,omitempty"` - RefID string `json:"refId,omitempty"` + Name string `json:"name,omitempty"` + RefID string `json:"refId,omitempty"` + Meta *cloudWatchFrameMeta `json:"meta,omitempty"` Fields []struct { Name string `json:"name"` Type string `json:"type"` @@ -332,8 +334,9 @@ func TestCloudWatchMultiFrameStatistics(t *testing.T) { Status int `json:"status,omitempty"` Frames []struct { Schema struct { - Name string `json:"name,omitempty"` - RefID string `json:"refId,omitempty"` + Name string `json:"name,omitempty"` + RefID string `json:"refId,omitempty"` + Meta *cloudWatchFrameMeta `json:"meta,omitempty"` Fields []struct { Name string `json:"name"` Type string `json:"type"` @@ -579,6 +582,46 @@ func TestCloudWatchAccountIdURLEncoding(t *testing.T) { } } +func TestCloudWatchCustomMetaUnmarshal(t *testing.T) { + tests := []struct { + name string + json string + expectedStatus string + }{ + { + name: "object with Status (Logs Insights response)", + json: `{"results":{"A":{"frames":[{"schema":{"meta":{"custom":{"Status":"Complete"}},"fields":[]},"data":{"values":[]}}]}}}`, + expectedStatus: "Complete", + }, + { + name: "string value (Metrics response)", + json: `{"results":{"A":{"frames":[{"schema":{"meta":{"custom":"timeSeriesQuery"},"fields":[]},"data":{"values":[]}}]}}}`, + expectedStatus: "", + }, + { + name: "no meta field", + json: `{"results":{"A":{"frames":[{"schema":{"fields":[]},"data":{"values":[]}}]}}}`, + expectedStatus: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var resp cloudWatchQueryResponse + err := json.Unmarshal([]byte(tt.json), &resp) + require.NoError(t, err, "unmarshaling should not fail regardless of custom field type") + + for _, r := range resp.Results { + for _, frame := range r.Frames { + if frame.Schema.Meta != nil { + assert.Equal(t, tt.expectedStatus, frame.Schema.Meta.Custom.Status) + } + } + } + }) + } +} + func TestGenerateCloudWatchEmptyResultHints(t *testing.T) { hints := generateCloudWatchEmptyResultHints()