Skip to content

Commit fe313c3

Browse files
authored
Merge pull request #1970 from dgageot/board/multiple-lsp-7f7e9871
Add LSP multiplexer to support multiple LSP toolsets
2 parents 6047dab + dc49baf commit fe313c3

File tree

7 files changed

+494
-116
lines changed

7 files changed

+494
-116
lines changed

e2e/binary/binary_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ func TestExecMissingKeys(t *testing.T) {
4141
require.Contains(t, res.Stderr, "OPENAI_API_KEY")
4242
})
4343
}
44+
4445
func TestAutoComplete(t *testing.T) {
4546
t.Run("cli plugin auto-complete docker-agent", func(t *testing.T) {
4647
res, err := Exec(binDir+"/docker-agent", "__complete", "ser")

pkg/teamloader/teamloader.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,9 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates
427427
// getToolsForAgent returns the tool definitions for an agent based on its configuration
428428
func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry, configName string) ([]tools.ToolSet, []string) {
429429
var (
430-
toolSets []tools.ToolSet
431-
warnings []string
430+
toolSets []tools.ToolSet
431+
warnings []string
432+
lspBackends []builtin.LSPBackend
432433
)
433434

434435
deferredToolset := builtin.NewDeferredToolset()
@@ -460,9 +461,29 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
460461
}
461462
}
462463

464+
// Collect LSP backends for multiplexing when there are multiple.
465+
// Instead of adding them individually (which causes duplicate tool names),
466+
// they are combined into a single LSPMultiplexer after the loop.
467+
if toolset.Type == "lsp" {
468+
if lspTool, ok := tool.(*builtin.LSPTool); ok {
469+
lspBackends = append(lspBackends, builtin.LSPBackend{LSP: lspTool, Toolset: wrapped})
470+
continue
471+
}
472+
slog.Warn("Toolset configured as type 'lsp' but registry returned unexpected type; treating as regular toolset",
473+
"type", fmt.Sprintf("%T", tool), "command", toolset.Command)
474+
}
475+
463476
toolSets = append(toolSets, wrapped)
464477
}
465478

479+
// Merge LSP backends: if there are multiple, combine them into a single
480+
// multiplexer so the LLM sees one set of lsp_* tools instead of duplicates.
481+
if len(lspBackends) > 1 {
482+
toolSets = append(toolSets, builtin.NewLSPMultiplexer(lspBackends))
483+
} else if len(lspBackends) == 1 {
484+
toolSets = append(toolSets, lspBackends[0].Toolset)
485+
}
486+
466487
if deferredToolset.HasSources() {
467488
toolSets = append(toolSets, deferredToolset)
468489
}

pkg/teamloader/teamloader_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,88 @@ agents:
386386
assert.Equal(t, expected, rootAgent.AddPromptFiles())
387387
}
388388

389+
func TestGetToolsForAgent_MultipleLSPToolsetsAreCombined(t *testing.T) {
390+
t.Parallel()
391+
392+
a := &latest.AgentConfig{
393+
Instruction: "test",
394+
Toolsets: []latest.Toolset{
395+
{
396+
Type: "lsp",
397+
Command: "gopls",
398+
Version: "golang/tools@v0.21.0",
399+
FileTypes: []string{".go"},
400+
},
401+
{
402+
Type: "lsp",
403+
Command: "gopls",
404+
Version: "golang/tools@v0.21.0",
405+
FileTypes: []string{".mod"},
406+
},
407+
},
408+
}
409+
410+
runConfig := config.RuntimeConfig{
411+
EnvProviderForTests: &noEnvProvider{},
412+
}
413+
414+
got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry(), "test-config")
415+
require.Empty(t, warnings)
416+
417+
// Should have exactly one toolset (the multiplexer)
418+
require.Len(t, got, 1)
419+
420+
// Verify that we get no duplicate tool names
421+
allTools, err := got[0].Tools(t.Context())
422+
require.NoError(t, err)
423+
424+
seen := make(map[string]bool)
425+
for _, tool := range allTools {
426+
assert.False(t, seen[tool.Name], "duplicate tool name: %s", tool.Name)
427+
seen[tool.Name] = true
428+
}
429+
430+
// Verify LSP tools are present
431+
assert.True(t, seen["lsp_hover"])
432+
assert.True(t, seen["lsp_definition"])
433+
}
434+
435+
func TestGetToolsForAgent_SingleLSPToolsetNotWrapped(t *testing.T) {
436+
t.Parallel()
437+
438+
a := &latest.AgentConfig{
439+
Instruction: "test",
440+
Toolsets: []latest.Toolset{
441+
{
442+
Type: "lsp",
443+
Command: "gopls",
444+
Version: "golang/tools@v0.21.0",
445+
FileTypes: []string{".go"},
446+
},
447+
},
448+
}
449+
450+
runConfig := config.RuntimeConfig{
451+
EnvProviderForTests: &noEnvProvider{},
452+
}
453+
454+
got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry(), "test-config")
455+
require.Empty(t, warnings)
456+
457+
// Should have exactly one toolset that provides LSP tools.
458+
require.Len(t, got, 1)
459+
460+
allTools, err := got[0].Tools(t.Context())
461+
require.NoError(t, err)
462+
463+
var names []string
464+
for _, tool := range allTools {
465+
names = append(names, tool.Name)
466+
}
467+
assert.Contains(t, names, "lsp_hover")
468+
assert.Contains(t, names, "lsp_definition")
469+
}
470+
389471
func TestExternalDepthContext(t *testing.T) {
390472
t.Parallel()
391473

pkg/toolinstall/registry.go

Lines changed: 58 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package toolinstall
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"fmt"
@@ -10,9 +11,9 @@ import (
1011
"path/filepath"
1112
"strings"
1213
"sync"
13-
"time"
1414

1515
"github.com/goccy/go-yaml"
16+
"github.com/natefinch/atomic"
1617
)
1718

1819
// githubToken returns a GitHub personal access token from the environment,
@@ -35,8 +36,8 @@ func setGitHubAuth(req *http.Request) {
3536
}
3637

3738
const (
38-
registryBaseURL = "https://raw.githubusercontent.com/aquaproj/aqua-registry/main"
39-
registryCacheTTL = 24 * time.Hour
39+
registryBaseURL = "https://raw.githubusercontent.com/aquaproj/aqua-registry/main"
40+
registryIndexFile = "registry.yaml"
4041
)
4142

4243
// Package represents a parsed aqua registry package definition.
@@ -104,11 +105,6 @@ type Registry struct {
104105
httpClient *http.Client
105106
baseURL string
106107
cacheDir string
107-
108-
// In-memory cache for the parsed registry index, populated once via sync.Once.
109-
indexOnce sync.Once
110-
cachedIndex *registryIndex
111-
indexErr error
112108
}
113109

114110
var (
@@ -157,8 +153,8 @@ func (r *Registry) LookupByName(ctx context.Context, name string) (*Package, err
157153
}
158154
}
159155

160-
// Fallback: fetch the per-package YAML file.
161-
data, err := r.fetchCached(ctx, fmt.Sprintf("pkgs/%s/%s/registry.yaml", owner, repo), 0)
156+
// Fallback: fetch the per-package YAML file directly (no caching).
157+
data, err := r.getBody(ctx, r.baseURL+"/"+fmt.Sprintf("pkgs/%s/%s/registry.yaml", owner, repo))
162158
if err != nil {
163159
return nil, fmt.Errorf("fetching package %s: %w", name, err)
164160
}
@@ -213,88 +209,32 @@ func providesCommand(pkg *Package, command string) bool {
213209
return false
214210
}
215211

216-
// fetchIndex fetches and parses the full registry index, with caching.
217-
// The parsed result is cached in memory so that repeated calls within the
218-
// same Registry instance skip both the HTTP fetch and YAML deserialization.
212+
// fetchIndex fetches and parses the full registry index.
213+
// The raw YAML is cached to disk; on fetch failure the cached copy is used.
214+
// The YAML is re-parsed on every call — there is no in-memory cache.
219215
func (r *Registry) fetchIndex(ctx context.Context) (*registryIndex, error) {
220-
r.indexOnce.Do(func() {
221-
var data []byte
222-
data, r.indexErr = r.fetchCached(ctx, "registry.yaml", registryCacheTTL)
223-
if r.indexErr != nil {
224-
return
225-
}
226-
227-
var index registryIndex
228-
if err := yaml.Unmarshal(data, &index); err != nil {
229-
r.indexErr = fmt.Errorf("parsing registry index: %w", err)
230-
return
231-
}
232-
r.cachedIndex = &index
233-
})
234-
235-
return r.cachedIndex, r.indexErr
236-
}
237-
238-
// fetchCached fetches a file from the registry, using a local file cache.
239-
// A ttl of 0 means the cache never expires.
240-
func (r *Registry) fetchCached(ctx context.Context, path string, ttl time.Duration) ([]byte, error) {
241-
cachePath := filepath.Join(r.cacheDir, path)
242-
243-
// Return cached data if still fresh.
244-
if info, err := os.Stat(cachePath); err == nil {
245-
if ttl == 0 || time.Since(info.ModTime()) < ttl {
246-
return os.ReadFile(cachePath)
247-
}
248-
}
249-
250-
// Fetch from remote.
251-
url := r.baseURL + "/" + path
252-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
253-
if err != nil {
254-
if data, readErr := os.ReadFile(cachePath); readErr == nil {
255-
return data, nil
256-
}
257-
return nil, fmt.Errorf("creating request for %s: %w", url, err)
258-
}
259-
setGitHubAuth(req)
216+
cachePath := filepath.Join(r.cacheDir, registryIndexFile)
260217

261-
resp, err := r.httpClient.Do(req)
218+
data, err := r.getBody(ctx, r.baseURL+"/"+registryIndexFile)
262219
if err != nil {
263-
if data, readErr := os.ReadFile(cachePath); readErr == nil {
264-
return data, nil // stale cache beats no data
220+
// Fallback to stale disk cache.
221+
if cached, readErr := os.ReadFile(cachePath); readErr == nil {
222+
data = cached
223+
} else {
224+
return nil, err
265225
}
266-
return nil, fmt.Errorf("fetching %s: %w", url, err)
226+
} else {
227+
// Best-effort: persist to disk for future fallback.
228+
_ = os.MkdirAll(filepath.Dir(cachePath), 0o755)
229+
_ = atomic.WriteFile(cachePath, bytes.NewReader(data))
267230
}
268-
defer resp.Body.Close()
269231

270-
if resp.StatusCode != http.StatusOK {
271-
if data, readErr := os.ReadFile(cachePath); readErr == nil {
272-
return data, nil
273-
}
274-
return nil, fmt.Errorf("fetching %s: HTTP %d", url, resp.StatusCode)
232+
var index registryIndex
233+
if err := yaml.Unmarshal(data, &index); err != nil {
234+
return nil, fmt.Errorf("parsing registry index: %w", err)
275235
}
276236

277-
data, err := io.ReadAll(resp.Body)
278-
if err != nil {
279-
return nil, fmt.Errorf("reading response from %s: %w", url, err)
280-
}
281-
282-
// Write to cache atomically (best-effort): write to a temp file in the
283-
// same directory, then rename. This avoids races when multiple goroutines
284-
// fetch the same path concurrently.
285-
if err := os.MkdirAll(filepath.Dir(cachePath), 0o755); err == nil {
286-
if tmpFile, tmpErr := os.CreateTemp(filepath.Dir(cachePath), ".cache-*.tmp"); tmpErr == nil {
287-
if _, writeErr := tmpFile.Write(data); writeErr == nil {
288-
tmpFile.Close()
289-
_ = os.Rename(tmpFile.Name(), cachePath)
290-
} else {
291-
tmpFile.Close()
292-
_ = os.Remove(tmpFile.Name())
293-
}
294-
}
295-
}
296-
297-
return data, nil
237+
return &index, nil
298238
}
299239

300240
// githubRelease represents the relevant fields from the GitHub releases API.
@@ -307,7 +247,7 @@ func (r *Registry) latestVersion(ctx context.Context, owner, repo string) (strin
307247
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo)
308248

309249
var release githubRelease
310-
if err := r.fetchGitHubJSON(ctx, url, &release); err != nil {
250+
if err := r.getJSON(ctx, url, &release); err != nil {
311251
return "", fmt.Errorf("fetching latest release for %s/%s: %w", owner, repo, err)
312252
}
313253

@@ -324,7 +264,7 @@ func (r *Registry) latestVersionFiltered(ctx context.Context, owner, repo, tagPr
324264
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases?per_page=50", owner, repo)
325265

326266
var releases []githubRelease
327-
if err := r.fetchGitHubJSON(ctx, url, &releases); err != nil {
267+
if err := r.getJSON(ctx, url, &releases); err != nil {
328268
return "", fmt.Errorf("fetching releases for %s/%s: %w", owner, repo, err)
329269
}
330270

@@ -337,45 +277,59 @@ func (r *Registry) latestVersionFiltered(ctx context.Context, owner, repo, tagPr
337277
return "", fmt.Errorf("no release found for %s/%s with tag prefix %q", owner, repo, tagPrefix)
338278
}
339279

340-
// fetchGitHubJSON fetches a GitHub API endpoint and decodes the JSON response.
341-
func (r *Registry) fetchGitHubJSON(ctx context.Context, url string, target any) error {
280+
// doGet performs an authenticated GET request and returns the response.
281+
// The caller is responsible for closing the response body.
282+
func (r *Registry) doGet(ctx context.Context, url string, headers map[string]string) (*http.Response, error) {
342283
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
343284
if err != nil {
344-
return err
285+
return nil, err
286+
}
287+
for k, v := range headers {
288+
req.Header.Set(k, v)
345289
}
346-
req.Header.Set("Accept", "application/vnd.github+json")
347290
setGitHubAuth(req)
348291

349292
resp, err := r.httpClient.Do(req)
350293
if err != nil {
351-
return err
294+
return nil, err
352295
}
353-
defer resp.Body.Close()
354296

355297
if resp.StatusCode != http.StatusOK {
356-
return fmt.Errorf("HTTP %d", resp.StatusCode)
298+
resp.Body.Close()
299+
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
357300
}
358301

359-
return json.NewDecoder(resp.Body).Decode(target)
302+
return resp, nil
360303
}
361304

362-
// download opens an HTTP connection to the given URL and returns the
363-
// response body as an io.ReadCloser. The caller is responsible for closing it.
364-
func (r *Registry) download(ctx context.Context, url string) (io.ReadCloser, error) {
365-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
305+
// getBody performs a GET request and returns the full response body.
306+
func (r *Registry) getBody(ctx context.Context, url string) ([]byte, error) {
307+
resp, err := r.doGet(ctx, url, nil)
366308
if err != nil {
367309
return nil, err
368310
}
369-
setGitHubAuth(req)
311+
defer resp.Body.Close()
370312

371-
resp, err := r.httpClient.Do(req)
313+
return io.ReadAll(resp.Body)
314+
}
315+
316+
// getJSON performs a GET request and decodes the JSON response into target.
317+
func (r *Registry) getJSON(ctx context.Context, url string, target any) error {
318+
resp, err := r.doGet(ctx, url, map[string]string{"Accept": "application/vnd.github+json"})
372319
if err != nil {
373-
return nil, err
320+
return err
374321
}
322+
defer resp.Body.Close()
375323

376-
if resp.StatusCode != http.StatusOK {
377-
resp.Body.Close()
378-
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
324+
return json.NewDecoder(resp.Body).Decode(target)
325+
}
326+
327+
// download opens an HTTP connection to the given URL and returns the
328+
// response body as an io.ReadCloser. The caller is responsible for closing it.
329+
func (r *Registry) download(ctx context.Context, url string) (io.ReadCloser, error) {
330+
resp, err := r.doGet(ctx, url, nil)
331+
if err != nil {
332+
return nil, err
379333
}
380334

381335
return resp.Body, nil

0 commit comments

Comments
 (0)