Skip to content

Commit 892f88a

Browse files
committed
Refactor session ID handling in response capture and invocation methods. Fixes #7573.
1 parent a4f18b0 commit 892f88a

File tree

3 files changed

+147
-14
lines changed

3 files changed

+147
-14
lines changed

cli/azd/extensions/azure.ai.agents/internal/cmd/helpers.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,12 @@ func contextMap(agentCtx *AgentLocalContext, field string) map[string]string {
199199
// printSessionStatus prints the session line for the invoke banner.
200200
// label is the formatted prefix (e.g. "Session: " or "Session: ").
201201
func printSessionStatus(label, sid string) {
202-
if sid != "" {
203-
fmt.Printf("%s%s\n", label, sid)
204-
} else {
205-
fmt.Printf("%s(new — server will assign)\n", label)
206-
}
202+
fmt.Printf("%s%s\n", label, sid)
207203
}
208204

209205
// captureResponseSession reads the x-agent-session-id header from a response
210-
// and saves it when the caller had no pre-existing session (sid == "").
211-
// label is the formatted prefix for printing (e.g. "Session: ").
206+
// and updates the persisted session when the server returns a different ID
207+
// than what the client sent (or assigns one when none was sent).
212208
func captureResponseSession(
213209
ctx context.Context,
214210
azdClient *azdext.AzdClient,
@@ -217,12 +213,19 @@ func captureResponseSession(
217213
resp *http.Response,
218214
label string,
219215
) {
220-
if sid != "" || azdClient == nil {
216+
if azdClient == nil {
221217
return
222218
}
223-
if newSid := resp.Header.Get("x-agent-session-id"); newSid != "" {
224-
saveContextValue(ctx, azdClient, agentName, newSid, "sessions")
219+
newSid := resp.Header.Get("x-agent-session-id")
220+
if newSid == "" || newSid == sid {
221+
return
222+
}
223+
// Server assigned or reassigned a session ID — update persisted value.
224+
saveContextValue(ctx, azdClient, agentName, newSid, "sessions")
225+
if sid == "" {
225226
fmt.Printf("%s%s (assigned by server)\n", label, newSid)
227+
} else {
228+
fmt.Printf("%s%s (reassigned by server)\n", label, newSid)
226229
}
227230
}
228231

cli/azd/extensions/azure.ai.agents/internal/cmd/helpers_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package cmd
55

66
import (
7+
"net/http"
78
"os"
89
"path/filepath"
910
"testing"
@@ -178,3 +179,131 @@ func TestToServiceKey(t *testing.T) {
178179
})
179180
}
180181
}
182+
183+
func TestCaptureResponseSession_NilClient(t *testing.T) {
184+
t.Parallel()
185+
186+
tests := []struct {
187+
name string
188+
sid string
189+
headerVal string
190+
}{
191+
{name: "no header", sid: "", headerVal: ""},
192+
{name: "header present but nil client", sid: "", headerVal: "server-session-abc"},
193+
{name: "client sid set with header", sid: "existing-session", headerVal: "server-session-abc"},
194+
}
195+
196+
for _, tt := range tests {
197+
t.Run(tt.name, func(t *testing.T) {
198+
t.Parallel()
199+
200+
resp := &http.Response{Header: http.Header{}}
201+
if tt.headerVal != "" {
202+
resp.Header.Set("x-agent-session-id", tt.headerVal)
203+
}
204+
205+
// Must not panic with nil azdClient.
206+
captureResponseSession(t.Context(), nil, "test-agent", tt.sid, resp, "Session: ")
207+
})
208+
}
209+
}
210+
211+
func TestLoadSaveLocalContext(t *testing.T) {
212+
t.Parallel()
213+
214+
t.Run("round trip", func(t *testing.T) {
215+
t.Parallel()
216+
217+
dir := t.TempDir()
218+
configPath := filepath.Join(dir, ConfigFile)
219+
220+
agentCtx := &AgentLocalContext{
221+
AgentName: "my-agent",
222+
Sessions: map[string]string{"agent1": "sess-123"},
223+
}
224+
225+
if err := saveLocalContext(agentCtx, configPath); err != nil {
226+
t.Fatalf("saveLocalContext failed: %v", err)
227+
}
228+
229+
loaded := loadLocalContext(configPath)
230+
if loaded.AgentName != "my-agent" {
231+
t.Errorf("AgentName = %q, want %q", loaded.AgentName, "my-agent")
232+
}
233+
if loaded.Sessions["agent1"] != "sess-123" {
234+
t.Errorf("Sessions[agent1] = %q, want %q", loaded.Sessions["agent1"], "sess-123")
235+
}
236+
})
237+
238+
t.Run("missing file returns empty context", func(t *testing.T) {
239+
t.Parallel()
240+
241+
loaded := loadLocalContext(filepath.Join(t.TempDir(), "nonexistent.json"))
242+
if loaded.Sessions != nil {
243+
t.Errorf("expected nil Sessions for missing file, got %v", loaded.Sessions)
244+
}
245+
})
246+
247+
t.Run("corrupt file returns empty context", func(t *testing.T) {
248+
t.Parallel()
249+
250+
dir := t.TempDir()
251+
configPath := filepath.Join(dir, ConfigFile)
252+
if err := os.WriteFile(configPath, []byte("{bad json"), 0600); err != nil {
253+
t.Fatalf("failed to write corrupt file: %v", err)
254+
}
255+
256+
loaded := loadLocalContext(configPath)
257+
if loaded.Sessions != nil {
258+
t.Errorf("expected nil Sessions for corrupt file, got %v", loaded.Sessions)
259+
}
260+
})
261+
}
262+
263+
func TestContextMap(t *testing.T) {
264+
t.Parallel()
265+
266+
tests := []struct {
267+
name string
268+
field string
269+
}{
270+
{name: "sessions", field: "sessions"},
271+
{name: "conversations", field: "conversations"},
272+
{name: "invocations", field: "invocations"},
273+
{name: "unknown", field: "unknown"},
274+
}
275+
276+
for _, tt := range tests {
277+
t.Run(tt.name, func(t *testing.T) {
278+
t.Parallel()
279+
280+
agentCtx := &AgentLocalContext{}
281+
m := contextMap(agentCtx, tt.field)
282+
if m == nil {
283+
t.Fatal("expected non-nil map")
284+
}
285+
m["key"] = "value"
286+
287+
// For known fields, the map should be stored on the struct.
288+
switch tt.field {
289+
case "sessions":
290+
if agentCtx.Sessions["key"] != "value" {
291+
t.Error("sessions map not stored on struct")
292+
}
293+
case "conversations":
294+
if agentCtx.Conversations["key"] != "value" {
295+
t.Error("conversations map not stored on struct")
296+
}
297+
case "invocations":
298+
if agentCtx.Invocations["key"] != "value" {
299+
t.Error("invocations map not stored on struct")
300+
}
301+
case "unknown":
302+
// Detached map — verify it doesn't affect any struct field.
303+
if agentCtx.Sessions != nil || agentCtx.Conversations != nil || agentCtx.Invocations != nil {
304+
t.Error("unknown field should not initialize struct maps")
305+
}
306+
}
307+
})
308+
}
309+
}

cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error {
356356
}
357357

358358
// Session ID — routes to the same microVM container instance.
359-
// When empty, let the server assign one.
360-
sid, err := resolveStoredID(ctx, azdClient, name, a.flags.session, a.flags.newSession, "sessions", false)
359+
// Generated client-side (UUID) so it is always persisted locally, even if the request fails.
360+
sid, err := resolveStoredID(ctx, azdClient, name, a.flags.session, a.flags.newSession, "sessions", true)
361361
if err != nil {
362362
return err
363363
}
@@ -542,8 +542,9 @@ func (a *InvokeAction) invocationsRemote(ctx context.Context) error {
542542
return err
543543
}
544544

545-
// Session ID — routes to the same container instance
546-
sid, err := resolveStoredID(ctx, azdClient, name, a.flags.session, a.flags.newSession, "sessions", false)
545+
// Session ID — routes to the same container instance.
546+
// Generated client-side (UUID) so it is always persisted locally, even if the request fails.
547+
sid, err := resolveStoredID(ctx, azdClient, name, a.flags.session, a.flags.newSession, "sessions", true)
547548
if err != nil {
548549
return err
549550
}

0 commit comments

Comments
 (0)