Skip to content

Commit 79a2e6c

Browse files
JAORMXclaude
andauthored
Add Skills API client for registry extension (#4173)
The toolhive-registry-server exposes a Skills API as a ToolHive-specific extension under /v0.1/x/dev.toolhive/skills. This adds an HTTP client to query that API, following the same patterns as the existing server client. - Extract shared HTTP client builder and error types into shared.go so both the server client and new skills client reuse the same security controls (private IP policy, auth token injection, error handling with LimitReader) - Add SkillsClient interface with GetSkill, GetSkillVersion, ListSkills, SearchSkills, and ListSkillVersions methods - Add RegistryHTTPError with Unwrap() for structured 401/403 handling - Migrate existing server client to use the shared error type - Add comprehensive table-driven tests with httptest Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f4f4e82 commit 79a2e6c

4 files changed

Lines changed: 892 additions & 16 deletions

File tree

pkg/registry/api/client.go

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
v0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
1717
"gopkg.in/yaml.v3"
1818

19-
"github.com/stacklok/toolhive/pkg/networking"
2019
"github.com/stacklok/toolhive/pkg/registry/auth"
2120
"github.com/stacklok/toolhive/pkg/versions"
2221
)
@@ -63,18 +62,11 @@ type mcpRegistryClient struct {
6362
func NewClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (Client, error) {
6463
// Build HTTP client with security controls
6564
// If private IPs are allowed, also allow HTTP (for localhost testing)
66-
builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp)
67-
if allowPrivateIp {
68-
builder = builder.WithInsecureAllowHTTP(true)
69-
}
70-
httpClient, err := builder.Build()
65+
httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource)
7166
if err != nil {
72-
return nil, fmt.Errorf("failed to build HTTP client: %w", err)
67+
return nil, err
7368
}
7469

75-
// Wrap transport with auth if token source is provided
76-
httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource)
77-
7870
// Ensure base URL doesn't have trailing slash
7971
if baseURL[len(baseURL)-1] == '/' {
8072
baseURL = baseURL[:len(baseURL)-1]
@@ -112,8 +104,7 @@ func (c *mcpRegistryClient) GetServer(ctx context.Context, name string) (*v0.Ser
112104
}()
113105

114106
if resp.StatusCode != http.StatusOK {
115-
body, _ := io.ReadAll(resp.Body)
116-
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
107+
return nil, newRegistryHTTPError(resp)
117108
}
118109

119110
var serverResp v0.ServerResponse
@@ -207,8 +198,7 @@ func (c *mcpRegistryClient) fetchServersPage(
207198
}()
208199

209200
if resp.StatusCode != http.StatusOK {
210-
body, _ := io.ReadAll(resp.Body)
211-
return nil, "", fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
201+
return nil, "", newRegistryHTTPError(resp)
212202
}
213203

214204
var listResp v0.ServerListResponse
@@ -252,8 +242,7 @@ func (c *mcpRegistryClient) SearchServers(ctx context.Context, query string) ([]
252242
}()
253243

254244
if resp.StatusCode != http.StatusOK {
255-
body, _ := io.ReadAll(resp.Body)
256-
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
245+
return nil, newRegistryHTTPError(resp)
257246
}
258247

259248
var listResp v0.ServerListResponse

pkg/registry/api/shared.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package api
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
12+
"github.com/stacklok/toolhive/pkg/networking"
13+
"github.com/stacklok/toolhive/pkg/registry/auth"
14+
)
15+
16+
const maxErrorBodySize = 4096
17+
18+
// ErrRegistryUnauthorized is a sentinel error for 401/403 responses from registry APIs.
19+
var ErrRegistryUnauthorized = errors.New("registry authentication failed")
20+
21+
// RegistryHTTPError represents an HTTP error from a registry API endpoint.
22+
type RegistryHTTPError struct {
23+
StatusCode int
24+
Body string
25+
URL string
26+
}
27+
28+
func (e *RegistryHTTPError) Error() string {
29+
return fmt.Sprintf("registry API returned status %d for %s: %s", e.StatusCode, e.URL, e.Body)
30+
}
31+
32+
// Unwrap returns ErrRegistryUnauthorized for 401/403 status codes,
33+
// allowing callers to use errors.Is(err, ErrRegistryUnauthorized).
34+
func (e *RegistryHTTPError) Unwrap() error {
35+
if e.StatusCode == http.StatusUnauthorized || e.StatusCode == http.StatusForbidden {
36+
return ErrRegistryUnauthorized
37+
}
38+
return nil
39+
}
40+
41+
// buildHTTPClient creates an HTTP client with security controls and optional auth.
42+
// If allowPrivateIp is true, HTTP (non-HTTPS) is also allowed for localhost testing.
43+
func buildHTTPClient(allowPrivateIp bool, tokenSource auth.TokenSource) (*http.Client, error) {
44+
builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp)
45+
if allowPrivateIp {
46+
builder = builder.WithInsecureAllowHTTP(true)
47+
}
48+
httpClient, err := builder.Build()
49+
if err != nil {
50+
return nil, fmt.Errorf("failed to build HTTP client: %w", err)
51+
}
52+
httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource)
53+
return httpClient, nil
54+
}
55+
56+
// newRegistryHTTPError reads the response body (limited) and returns a RegistryHTTPError.
57+
func newRegistryHTTPError(resp *http.Response) *RegistryHTTPError {
58+
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize))
59+
return &RegistryHTTPError{
60+
StatusCode: resp.StatusCode,
61+
Body: string(body),
62+
URL: resp.Request.URL.String(),
63+
}
64+
}

pkg/registry/api/skills_client.go

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package api
5+
6+
import (
7+
"context"
8+
"encoding/json"
9+
"fmt"
10+
"log/slog"
11+
"net/http"
12+
"net/url"
13+
"strings"
14+
15+
thvregistry "github.com/stacklok/toolhive-core/registry/types"
16+
"github.com/stacklok/toolhive/pkg/registry/auth"
17+
"github.com/stacklok/toolhive/pkg/versions"
18+
)
19+
20+
const skillsBasePath = "/v0.1/x/dev.toolhive/skills"
21+
22+
// SkillsListOptions contains options for listing skills.
23+
type SkillsListOptions struct {
24+
// Search is an optional search query to filter skills.
25+
Search string
26+
// Limit is the maximum number of skills per page (default: 100).
27+
Limit int
28+
// Cursor is the pagination cursor for fetching the next page.
29+
Cursor string
30+
}
31+
32+
// SkillsListResult contains a page of skills and pagination info.
33+
type SkillsListResult struct {
34+
Skills []*thvregistry.Skill
35+
NextCursor string
36+
}
37+
38+
// SkillsClient provides access to the ToolHive Skills extension API.
39+
type SkillsClient interface {
40+
// GetSkill retrieves a skill by namespace and name (latest version).
41+
GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error)
42+
// GetSkillVersion retrieves a specific version of a skill.
43+
GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error)
44+
// ListSkills retrieves skills with optional filtering and pagination.
45+
ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error)
46+
// SearchSkills searches for skills matching the query (single page, no auto-pagination).
47+
SearchSkills(ctx context.Context, query string) (*SkillsListResult, error)
48+
// ListSkillVersions lists all versions of a specific skill.
49+
ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error)
50+
}
51+
52+
// NewSkillsClient creates a new ToolHive Skills extension API client.
53+
// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject
54+
// Bearer tokens into all requests.
55+
func NewSkillsClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (SkillsClient, error) {
56+
httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
// Ensure base URL doesn't have trailing slash
62+
baseURL = strings.TrimRight(baseURL, "/")
63+
64+
return &mcpSkillsClient{
65+
baseURL: baseURL,
66+
httpClient: httpClient,
67+
userAgent: versions.GetUserAgent(),
68+
}, nil
69+
}
70+
71+
// GetSkill retrieves a skill by namespace and name (latest version).
72+
func (c *mcpSkillsClient) GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) {
73+
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name))
74+
if err != nil {
75+
return nil, fmt.Errorf("failed to build skills URL: %w", err)
76+
}
77+
78+
var skill thvregistry.Skill
79+
if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil {
80+
return nil, err
81+
}
82+
return &skill, nil
83+
}
84+
85+
// GetSkillVersion retrieves a specific version of a skill.
86+
func (c *mcpSkillsClient) GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) {
87+
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath,
88+
url.PathEscape(namespace), url.PathEscape(name),
89+
"versions", url.PathEscape(version))
90+
if err != nil {
91+
return nil, fmt.Errorf("failed to build skills URL: %w", err)
92+
}
93+
94+
var skill thvregistry.Skill
95+
if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil {
96+
return nil, err
97+
}
98+
return &skill, nil
99+
}
100+
101+
// ListSkills retrieves skills with optional filtering and pagination.
102+
// It auto-paginates through all available pages, concatenating results.
103+
func (c *mcpSkillsClient) ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) {
104+
if opts == nil {
105+
opts = &SkillsListOptions{}
106+
}
107+
if opts.Limit == 0 {
108+
opts.Limit = 100
109+
}
110+
111+
var allSkills []*thvregistry.Skill
112+
cursor := opts.Cursor
113+
114+
// Pagination loop - continue until no more cursors
115+
for {
116+
page, nextCursor, err := c.fetchSkillsPage(ctx, cursor, opts)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
allSkills = append(allSkills, page...)
122+
123+
// Check if we have more pages
124+
if nextCursor == "" {
125+
break
126+
}
127+
128+
cursor = nextCursor
129+
130+
// Safety limit: prevent infinite loops
131+
if len(allSkills) > 10000 {
132+
return nil, fmt.Errorf("exceeded maximum skills limit (10000)")
133+
}
134+
}
135+
136+
return &SkillsListResult{
137+
Skills: allSkills,
138+
}, nil
139+
}
140+
141+
// SearchSkills searches for skills matching the query.
142+
// Returns a single page of results (no auto-pagination).
143+
func (c *mcpSkillsClient) SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) {
144+
basePath, err := url.JoinPath(c.baseURL, skillsBasePath)
145+
if err != nil {
146+
return nil, fmt.Errorf("failed to build skills URL: %w", err)
147+
}
148+
params := url.Values{}
149+
params.Add("search", query)
150+
151+
endpoint := basePath + "?" + params.Encode()
152+
153+
var listResp skillsListResponse
154+
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
155+
return nil, err
156+
}
157+
158+
return &SkillsListResult{
159+
Skills: listResp.Skills,
160+
NextCursor: listResp.Metadata.NextCursor,
161+
}, nil
162+
}
163+
164+
// ListSkillVersions lists all versions of a specific skill.
165+
func (c *mcpSkillsClient) ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) {
166+
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name), "versions")
167+
if err != nil {
168+
return nil, fmt.Errorf("failed to build skills URL: %w", err)
169+
}
170+
171+
var listResp skillsListResponse
172+
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
173+
return nil, err
174+
}
175+
176+
return &SkillsListResult{
177+
Skills: listResp.Skills,
178+
NextCursor: listResp.Metadata.NextCursor,
179+
}, nil
180+
}
181+
182+
// mcpSkillsClient implements the SkillsClient interface.
183+
type mcpSkillsClient struct {
184+
baseURL string
185+
httpClient *http.Client
186+
userAgent string
187+
}
188+
189+
// skillsListResponse is the wire format for list/search responses.
190+
type skillsListResponse struct {
191+
Skills []*thvregistry.Skill `json:"skills"`
192+
Metadata struct {
193+
Count int `json:"count"`
194+
NextCursor string `json:"nextCursor"`
195+
} `json:"metadata"`
196+
}
197+
198+
// doSkillsGet performs an HTTP GET request and decodes the JSON response into dest.
199+
func (c *mcpSkillsClient) doSkillsGet(ctx context.Context, endpoint string, dest any) error {
200+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
201+
if err != nil {
202+
return fmt.Errorf("failed to create request: %w", err)
203+
}
204+
req.Header.Set("User-Agent", c.userAgent)
205+
206+
resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL from configured registry
207+
if err != nil {
208+
return fmt.Errorf("failed to execute request: %w", err)
209+
}
210+
defer func() {
211+
if err := resp.Body.Close(); err != nil {
212+
slog.Debug("failed to close response body", "error", err)
213+
}
214+
}()
215+
216+
if resp.StatusCode != http.StatusOK {
217+
return newRegistryHTTPError(resp)
218+
}
219+
220+
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
221+
return fmt.Errorf("failed to decode response: %w", err)
222+
}
223+
return nil
224+
}
225+
226+
// fetchSkillsPage fetches a single page of skills.
227+
func (c *mcpSkillsClient) fetchSkillsPage(
228+
ctx context.Context, cursor string, opts *SkillsListOptions,
229+
) ([]*thvregistry.Skill, string, error) {
230+
params := url.Values{}
231+
if cursor != "" {
232+
params.Add("cursor", cursor)
233+
}
234+
if opts.Limit > 0 {
235+
params.Add("limit", fmt.Sprintf("%d", opts.Limit))
236+
}
237+
if opts.Search != "" {
238+
params.Add("search", opts.Search)
239+
}
240+
241+
basePath, err := url.JoinPath(c.baseURL, skillsBasePath)
242+
if err != nil {
243+
return nil, "", fmt.Errorf("failed to build skills URL: %w", err)
244+
}
245+
endpoint := func() string {
246+
if len(params) > 0 {
247+
return basePath + "?" + params.Encode()
248+
}
249+
return basePath
250+
}()
251+
252+
var listResp skillsListResponse
253+
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
254+
return nil, "", err
255+
}
256+
257+
return listResp.Skills, listResp.Metadata.NextCursor, nil
258+
}

0 commit comments

Comments
 (0)