Skip to content

Commit e9e7ba5

Browse files
committed
fix(providers): decode shared local probes per provider
Address the latest Copilot review comments on PR #30. Local provider discovery still dedupes network requests by probe URL, but the shared work now stops at fetching the raw HTTP response. Each preset decodes that raw response separately with its own provider identity, so shared endpoints such as llama.cpp and LocalAI cannot inherit whichever provider happened to start the probe first. This preserves the one-request-per-endpoint behavior while keeping schema errors and model decoding provider-specific. The local discovery UI type now models status as an actual closed union with an explicit unknown fallback instead of unioning the literals with string, which collapsed the type to plain string and removed useful compile-time checking. Regression coverage asserts that shared invalid endpoint responses are requested once but produce provider-specific decode errors for every preset using that endpoint. Verified with: - cd ui && bun run typecheck - cd ui && bun run test -- api provider-utils ProvidersView - cd ui && bun run test - GOCACHE=/Users/chicoxyzzy/dev/hecate/.gocache go test ./internal/api -run 'TestDiscoverLocalProviders|TestLocalProviderProbeURL' - GOCACHE=/Users/chicoxyzzy/dev/hecate/.gocache go test ./internal/api
1 parent f324478 commit e9e7ba5

3 files changed

Lines changed: 65 additions & 15 deletions

File tree

internal/api/handler_local_provider_discovery.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package api
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
78
"fmt"
9+
"io"
810
"net/http"
911
"net/url"
1012
"os/exec"
@@ -29,6 +31,12 @@ type localHTTPProbeResult struct {
2931
err string
3032
}
3133

34+
type localHTTPFetchResult struct {
35+
statusCode int
36+
body []byte
37+
err string
38+
}
39+
3240
type localProviderProbe struct {
3341
provider config.BuiltInProvider
3442
probeURL string
@@ -42,7 +50,7 @@ type localProviderDiscoveryResult struct {
4250

4351
type localHTTPProbeTask struct {
4452
done chan struct{}
45-
result localHTTPProbeResult
53+
result localHTTPFetchResult
4654
}
4755

4856
func (h *Handler) HandleLocalProviderDiscovery(w http.ResponseWriter, r *http.Request) {
@@ -111,7 +119,7 @@ func discoverLocalProviderPairsConcurrently(ctx context.Context, providers []loc
111119
var probesMu sync.Mutex
112120
var wg sync.WaitGroup
113121

114-
getProbe := func(probeURL, providerID string) *localHTTPProbeTask {
122+
getProbe := func(probeURL string) *localHTTPProbeTask {
115123
probesMu.Lock()
116124
defer probesMu.Unlock()
117125
if task, ok := probes[probeURL]; ok {
@@ -122,7 +130,7 @@ func discoverLocalProviderPairsConcurrently(ctx context.Context, providers []loc
122130
go func() {
123131
probeCtx, cancel := context.WithTimeout(ctx, localProviderDiscoveryTimeout)
124132
defer cancel()
125-
task.result = probeLocalProviderHTTP(probeCtx, client, probeURL, providerID)
133+
task.result = fetchLocalProviderHTTP(probeCtx, client, probeURL)
126134
close(task.done)
127135
}()
128136
return task
@@ -133,11 +141,11 @@ func discoverLocalProviderPairsConcurrently(ctx context.Context, providers []loc
133141
go func() {
134142
defer wg.Done()
135143
command, path := findLocalProviderCommand(entry.provider.ID, lookPath)
136-
task := getProbe(entry.probeURL, entry.provider.ID)
144+
task := getProbe(entry.probeURL)
137145
var httpResult localHTTPProbeResult
138146
select {
139147
case <-task.done:
140-
httpResult = task.result
148+
httpResult = decodeLocalProviderHTTPProbe(task.result, entry.provider.ID)
141149
case <-ctx.Done():
142150
httpResult = localHTTPProbeResult{err: compactLocalProbeError(ctx.Err())}
143151
}
@@ -194,36 +202,47 @@ func localProviderProbeURL(provider config.BuiltInProvider) string {
194202
return base + "/models"
195203
}
196204

197-
func probeLocalProviderHTTP(ctx context.Context, client localProviderHTTPDoer, probeURL, providerID string) localHTTPProbeResult {
205+
func fetchLocalProviderHTTP(ctx context.Context, client localProviderHTTPDoer, probeURL string) localHTTPFetchResult {
198206
req, err := http.NewRequestWithContext(ctx, http.MethodGet, probeURL, nil)
199207
if err != nil {
200-
return localHTTPProbeResult{err: err.Error()}
208+
return localHTTPFetchResult{err: err.Error()}
201209
}
202210
resp, err := client.Do(req)
203211
if err != nil {
204-
return localHTTPProbeResult{err: compactLocalProbeError(err)}
212+
return localHTTPFetchResult{err: compactLocalProbeError(err)}
205213
}
206214
defer resp.Body.Close()
207215

208-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
209-
return localHTTPProbeResult{err: fmt.Sprintf("HTTP %d", resp.StatusCode)}
216+
body, err := io.ReadAll(resp.Body)
217+
if err != nil {
218+
return localHTTPFetchResult{statusCode: resp.StatusCode, err: err.Error()}
219+
}
220+
return localHTTPFetchResult{statusCode: resp.StatusCode, body: body}
221+
}
222+
223+
func decodeLocalProviderHTTPProbe(fetch localHTTPFetchResult, providerID string) localHTTPProbeResult {
224+
if fetch.err != "" {
225+
return localHTTPProbeResult{err: fetch.err}
226+
}
227+
if fetch.statusCode < 200 || fetch.statusCode >= 300 {
228+
return localHTTPProbeResult{err: fmt.Sprintf("HTTP %d", fetch.statusCode)}
210229
}
211230

212-
models, err := decodeLocalProviderModels(resp, providerID)
231+
models, err := decodeLocalProviderModels(fetch.body, providerID)
213232
if err != nil {
214233
return localHTTPProbeResult{err: err.Error()}
215234
}
216235
return localHTTPProbeResult{available: true, models: models}
217236
}
218237

219-
func decodeLocalProviderModels(resp *http.Response, providerID string) ([]string, error) {
238+
func decodeLocalProviderModels(body []byte, providerID string) ([]string, error) {
220239
if providerID == "ollama" {
221240
var payload struct {
222241
Models []struct {
223242
Name string `json:"name"`
224243
} `json:"models"`
225244
}
226-
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
245+
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&payload); err != nil {
227246
return nil, fmt.Errorf("invalid %s response: %w", providerID, err)
228247
}
229248
models := make([]string, 0, len(payload.Models))
@@ -240,7 +259,7 @@ func decodeLocalProviderModels(resp *http.Response, providerID string) ([]string
240259
ID string `json:"id"`
241260
} `json:"data"`
242261
}
243-
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
262+
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&payload); err != nil {
244263
return nil, fmt.Errorf("invalid %s response: %w", providerID, err)
245264
}
246265
models := make([]string, 0, len(payload.Data))

internal/api/handler_local_provider_discovery_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,37 @@ func TestDiscoverLocalProvidersRejectsInvalidHTTPProbeBody(t *testing.T) {
293293
}
294294
}
295295

296+
func TestDiscoverLocalProvidersDecodesSharedHTTPProbePerProvider(t *testing.T) {
297+
t.Parallel()
298+
299+
providers := []config.BuiltInProvider{
300+
{ID: "llamacpp", Name: "llama.cpp", Kind: "local", BaseURL: "http://127.0.0.1:8080/v1"},
301+
{ID: "localai", Name: "LocalAI", Kind: "local", BaseURL: "http://127.0.0.1:8080/v1"},
302+
}
303+
rt := &localProviderRoundTrip{
304+
body: map[string]string{
305+
"http://127.0.0.1:8080/v1/models": `not-json`,
306+
},
307+
}
308+
309+
items := discoverLocalProviders(context.Background(), providers, missingLocalCommand, rt)
310+
311+
if len(items) != 2 {
312+
t.Fatalf("items = %d, want 2", len(items))
313+
}
314+
if got := rt.calls["http://127.0.0.1:8080/v1/models"]; got != 1 {
315+
t.Fatalf("shared endpoint request count = %d, want 1", got)
316+
}
317+
for _, item := range items {
318+
if item.HTTPAvailable {
319+
t.Fatalf("%s HTTPAvailable = true, want false", item.PresetID)
320+
}
321+
if !strings.Contains(item.Error, "invalid "+item.PresetID+" response") {
322+
t.Fatalf("%s error = %q, want provider-specific decode error", item.PresetID, item.Error)
323+
}
324+
}
325+
}
326+
296327
func TestLocalProviderProbeURLUsesOllamaNativeTagsEndpoint(t *testing.T) {
297328
t.Parallel()
298329

ui/src/types/runtime.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ export type LocalProviderDiscoveryRecord = {
175175
name: string;
176176
base_url: string;
177177
probe_url: string;
178-
status: "running" | "installed" | "not_detected" | "error" | string;
178+
status: "running" | "installed" | "not_detected" | "error" | "unknown";
179179
command?: string;
180180
command_available: boolean;
181181
command_path?: string;

0 commit comments

Comments
 (0)