|
| 1 | +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package api |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + "encoding/json" |
| 9 | + "fmt" |
| 10 | + "log/slog" |
| 11 | + "net/http" |
| 12 | + "net/url" |
| 13 | + "strings" |
| 14 | + |
| 15 | + thvregistry "github.com/stacklok/toolhive-core/registry/types" |
| 16 | + "github.com/stacklok/toolhive/pkg/registry/auth" |
| 17 | + "github.com/stacklok/toolhive/pkg/versions" |
| 18 | +) |
| 19 | + |
| 20 | +const skillsBasePath = "/v0.1/x/dev.toolhive/skills" |
| 21 | + |
| 22 | +// SkillsListOptions contains options for listing skills. |
| 23 | +type SkillsListOptions struct { |
| 24 | + // Search is an optional search query to filter skills. |
| 25 | + Search string |
| 26 | + // Limit is the maximum number of skills per page (default: 100). |
| 27 | + Limit int |
| 28 | + // Cursor is the pagination cursor for fetching the next page. |
| 29 | + Cursor string |
| 30 | +} |
| 31 | + |
| 32 | +// SkillsListResult contains a page of skills and pagination info. |
| 33 | +type SkillsListResult struct { |
| 34 | + Skills []*thvregistry.Skill |
| 35 | + NextCursor string |
| 36 | +} |
| 37 | + |
| 38 | +// SkillsClient provides access to the ToolHive Skills extension API. |
| 39 | +type SkillsClient interface { |
| 40 | + // GetSkill retrieves a skill by namespace and name (latest version). |
| 41 | + GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) |
| 42 | + // GetSkillVersion retrieves a specific version of a skill. |
| 43 | + GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) |
| 44 | + // ListSkills retrieves skills with optional filtering and pagination. |
| 45 | + ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) |
| 46 | + // SearchSkills searches for skills matching the query (single page, no auto-pagination). |
| 47 | + SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) |
| 48 | + // ListSkillVersions lists all versions of a specific skill. |
| 49 | + ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) |
| 50 | +} |
| 51 | + |
| 52 | +// NewSkillsClient creates a new ToolHive Skills extension API client. |
| 53 | +// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject |
| 54 | +// Bearer tokens into all requests. |
| 55 | +func NewSkillsClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (SkillsClient, error) { |
| 56 | + httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource) |
| 57 | + if err != nil { |
| 58 | + return nil, err |
| 59 | + } |
| 60 | + |
| 61 | + // Ensure base URL doesn't have trailing slash |
| 62 | + baseURL = strings.TrimRight(baseURL, "/") |
| 63 | + |
| 64 | + return &mcpSkillsClient{ |
| 65 | + baseURL: baseURL, |
| 66 | + httpClient: httpClient, |
| 67 | + userAgent: versions.GetUserAgent(), |
| 68 | + }, nil |
| 69 | +} |
| 70 | + |
| 71 | +// GetSkill retrieves a skill by namespace and name (latest version). |
| 72 | +func (c *mcpSkillsClient) GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) { |
| 73 | + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name)) |
| 74 | + if err != nil { |
| 75 | + return nil, fmt.Errorf("failed to build skills URL: %w", err) |
| 76 | + } |
| 77 | + |
| 78 | + var skill thvregistry.Skill |
| 79 | + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { |
| 80 | + return nil, err |
| 81 | + } |
| 82 | + return &skill, nil |
| 83 | +} |
| 84 | + |
| 85 | +// GetSkillVersion retrieves a specific version of a skill. |
| 86 | +func (c *mcpSkillsClient) GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) { |
| 87 | + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, |
| 88 | + url.PathEscape(namespace), url.PathEscape(name), |
| 89 | + "versions", url.PathEscape(version)) |
| 90 | + if err != nil { |
| 91 | + return nil, fmt.Errorf("failed to build skills URL: %w", err) |
| 92 | + } |
| 93 | + |
| 94 | + var skill thvregistry.Skill |
| 95 | + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { |
| 96 | + return nil, err |
| 97 | + } |
| 98 | + return &skill, nil |
| 99 | +} |
| 100 | + |
| 101 | +// ListSkills retrieves skills with optional filtering and pagination. |
| 102 | +// It auto-paginates through all available pages, concatenating results. |
| 103 | +func (c *mcpSkillsClient) ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) { |
| 104 | + if opts == nil { |
| 105 | + opts = &SkillsListOptions{} |
| 106 | + } |
| 107 | + if opts.Limit == 0 { |
| 108 | + opts.Limit = 100 |
| 109 | + } |
| 110 | + |
| 111 | + var allSkills []*thvregistry.Skill |
| 112 | + cursor := opts.Cursor |
| 113 | + |
| 114 | + // Pagination loop - continue until no more cursors |
| 115 | + for { |
| 116 | + page, nextCursor, err := c.fetchSkillsPage(ctx, cursor, opts) |
| 117 | + if err != nil { |
| 118 | + return nil, err |
| 119 | + } |
| 120 | + |
| 121 | + allSkills = append(allSkills, page...) |
| 122 | + |
| 123 | + // Check if we have more pages |
| 124 | + if nextCursor == "" { |
| 125 | + break |
| 126 | + } |
| 127 | + |
| 128 | + cursor = nextCursor |
| 129 | + |
| 130 | + // Safety limit: prevent infinite loops |
| 131 | + if len(allSkills) > 10000 { |
| 132 | + return nil, fmt.Errorf("exceeded maximum skills limit (10000)") |
| 133 | + } |
| 134 | + } |
| 135 | + |
| 136 | + return &SkillsListResult{ |
| 137 | + Skills: allSkills, |
| 138 | + }, nil |
| 139 | +} |
| 140 | + |
| 141 | +// SearchSkills searches for skills matching the query. |
| 142 | +// Returns a single page of results (no auto-pagination). |
| 143 | +func (c *mcpSkillsClient) SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) { |
| 144 | + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) |
| 145 | + if err != nil { |
| 146 | + return nil, fmt.Errorf("failed to build skills URL: %w", err) |
| 147 | + } |
| 148 | + params := url.Values{} |
| 149 | + params.Add("search", query) |
| 150 | + |
| 151 | + endpoint := basePath + "?" + params.Encode() |
| 152 | + |
| 153 | + var listResp skillsListResponse |
| 154 | + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { |
| 155 | + return nil, err |
| 156 | + } |
| 157 | + |
| 158 | + return &SkillsListResult{ |
| 159 | + Skills: listResp.Skills, |
| 160 | + NextCursor: listResp.Metadata.NextCursor, |
| 161 | + }, nil |
| 162 | +} |
| 163 | + |
| 164 | +// ListSkillVersions lists all versions of a specific skill. |
| 165 | +func (c *mcpSkillsClient) ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) { |
| 166 | + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name), "versions") |
| 167 | + if err != nil { |
| 168 | + return nil, fmt.Errorf("failed to build skills URL: %w", err) |
| 169 | + } |
| 170 | + |
| 171 | + var listResp skillsListResponse |
| 172 | + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { |
| 173 | + return nil, err |
| 174 | + } |
| 175 | + |
| 176 | + return &SkillsListResult{ |
| 177 | + Skills: listResp.Skills, |
| 178 | + NextCursor: listResp.Metadata.NextCursor, |
| 179 | + }, nil |
| 180 | +} |
| 181 | + |
| 182 | +// mcpSkillsClient implements the SkillsClient interface. |
| 183 | +type mcpSkillsClient struct { |
| 184 | + baseURL string |
| 185 | + httpClient *http.Client |
| 186 | + userAgent string |
| 187 | +} |
| 188 | + |
| 189 | +// skillsListResponse is the wire format for list/search responses. |
| 190 | +type skillsListResponse struct { |
| 191 | + Skills []*thvregistry.Skill `json:"skills"` |
| 192 | + Metadata struct { |
| 193 | + Count int `json:"count"` |
| 194 | + NextCursor string `json:"nextCursor"` |
| 195 | + } `json:"metadata"` |
| 196 | +} |
| 197 | + |
| 198 | +// doSkillsGet performs an HTTP GET request and decodes the JSON response into dest. |
| 199 | +func (c *mcpSkillsClient) doSkillsGet(ctx context.Context, endpoint string, dest any) error { |
| 200 | + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) |
| 201 | + if err != nil { |
| 202 | + return fmt.Errorf("failed to create request: %w", err) |
| 203 | + } |
| 204 | + req.Header.Set("User-Agent", c.userAgent) |
| 205 | + |
| 206 | + resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL from configured registry |
| 207 | + if err != nil { |
| 208 | + return fmt.Errorf("failed to execute request: %w", err) |
| 209 | + } |
| 210 | + defer func() { |
| 211 | + if err := resp.Body.Close(); err != nil { |
| 212 | + slog.Debug("failed to close response body", "error", err) |
| 213 | + } |
| 214 | + }() |
| 215 | + |
| 216 | + if resp.StatusCode != http.StatusOK { |
| 217 | + return newRegistryHTTPError(resp) |
| 218 | + } |
| 219 | + |
| 220 | + if err := json.NewDecoder(resp.Body).Decode(dest); err != nil { |
| 221 | + return fmt.Errorf("failed to decode response: %w", err) |
| 222 | + } |
| 223 | + return nil |
| 224 | +} |
| 225 | + |
| 226 | +// fetchSkillsPage fetches a single page of skills. |
| 227 | +func (c *mcpSkillsClient) fetchSkillsPage( |
| 228 | + ctx context.Context, cursor string, opts *SkillsListOptions, |
| 229 | +) ([]*thvregistry.Skill, string, error) { |
| 230 | + params := url.Values{} |
| 231 | + if cursor != "" { |
| 232 | + params.Add("cursor", cursor) |
| 233 | + } |
| 234 | + if opts.Limit > 0 { |
| 235 | + params.Add("limit", fmt.Sprintf("%d", opts.Limit)) |
| 236 | + } |
| 237 | + if opts.Search != "" { |
| 238 | + params.Add("search", opts.Search) |
| 239 | + } |
| 240 | + |
| 241 | + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) |
| 242 | + if err != nil { |
| 243 | + return nil, "", fmt.Errorf("failed to build skills URL: %w", err) |
| 244 | + } |
| 245 | + endpoint := func() string { |
| 246 | + if len(params) > 0 { |
| 247 | + return basePath + "?" + params.Encode() |
| 248 | + } |
| 249 | + return basePath |
| 250 | + }() |
| 251 | + |
| 252 | + var listResp skillsListResponse |
| 253 | + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { |
| 254 | + return nil, "", err |
| 255 | + } |
| 256 | + |
| 257 | + return listResp.Skills, listResp.Metadata.NextCursor, nil |
| 258 | +} |
0 commit comments