From 5a3601dd41153f7cd1a18a7454b0ab3bc7483712 Mon Sep 17 00:00:00 2001 From: xunzhuo Date: Wed, 15 Apr 2026 19:41:49 +0800 Subject: [PATCH 1/2] feat: support session aware routing Signed-off-by: xunzhuo --- config/README.md | 2 +- config/algorithm/selection/session-aware.yaml | 9 + config/config.yaml | 42 ++- config/signal/session/runtime-facts.yaml | 22 ++ dashboard/backend/handlers/topology.go | 1 + .../backend/handlers/topology_response.go | 1 + dashboard/frontend/src/pages/ConfigPage.tsx | 3 + .../src/pages/ConfigPageDecisionsSection.tsx | 4 + .../src/pages/ConfigPageSignalsSection.tsx | 99 +++++- .../frontend/src/pages/DashboardPage.tsx | 2 + .../pages/configPageRouterDefaultsSupport.ts | 5 +- .../frontend/src/pages/configPageSupport.ts | 9 + .../components/CustomNodes/AlgorithmNode.tsx | 13 + .../frontend/src/pages/topology/constants.ts | 2 + .../frontend/src/pages/topology/types.ts | 25 ++ .../pages/topology/utils/topologyParser.ts | 20 +- dashboard/frontend/src/types/config.ts | 12 +- deploy/recipes/session-state.dsl | 23 -- deploy/recipes/session-state.yaml | 16 - docs/agent/plans/README.md | 1 + ...-session-aware-routing-convergence-loop.md | 63 ++++ .../classification/classifier_signal_eval.go | 6 +- .../pkg/config/canonical_config.go | 12 +- .../pkg/config/canonical_export.go | 10 +- src/semantic-router/pkg/config/config.go | 11 +- .../pkg/config/decision_config.go | 1 + .../pkg/config/fragment_catalog_test.go | 1 + .../reference_config_public_surface_test.go | 22 +- .../reference_config_routing_surface_test.go | 2 + .../pkg/config/routing_surface_catalog.go | 2 + .../pkg/config/selection_config.go | 11 +- .../pkg/config/session_selection_config.go | 26 ++ .../pkg/config/session_signal_config.go | 44 +++ .../pkg/config/session_state_config.go | 14 - .../pkg/config/signal_config.go | 1 + src/semantic-router/pkg/config/validator.go | 3 + .../pkg/config/validator_decision.go | 11 + .../pkg/config/validator_projection.go | 10 + .../pkg/config/validator_projection_test.go | 42 +++ .../pkg/config/validator_session.go | 83 +++++ src/semantic-router/pkg/decision/engine.go | 1 + src/semantic-router/pkg/dsl/ast.go | 22 -- src/semantic-router/pkg/dsl/compiler.go | 63 +++- src/semantic-router/pkg/dsl/decompiler.go | 130 +++++++- src/semantic-router/pkg/dsl/dsl_test.go | 211 ++++-------- .../dsl/maintained_asset_roundtrip_test.go | 23 -- src/semantic-router/pkg/dsl/parser.go | 29 +- .../pkg/dsl/routing_contract.go | 22 +- .../pkg/dsl/validator_conflicts.go | 41 --- src/semantic-router/pkg/extproc/recorder.go | 41 ++- .../pkg/extproc/recorder_test.go | 20 ++ .../pkg/extproc/req_filter_classification.go | 18 +- .../req_filter_classification_runtime.go | 5 + .../req_filter_classification_session.go | 98 ++++++ .../req_filter_classification_signal.go | 2 + .../pkg/extproc/request_context.go | 1 + src/semantic-router/pkg/extproc/router.go | 4 +- .../pkg/extproc/router_build.go | 6 +- .../pkg/extproc/router_selection.go | 51 ++- .../pkg/extproc/router_selection_context.go | 111 ++++++ .../pkg/k8s/testdata/output/01-basic.yaml | 2 + .../k8s/testdata/output/02-keyword-only.yaml | 2 + .../testdata/output/03-embedding-only.yaml | 2 + .../k8s/testdata/output/04-domain-only.yaml | 2 + .../testdata/output/05-keyword-embedding.yaml | 2 + .../testdata/output/06-keyword-domain.yaml | 2 + .../testdata/output/07-domain-embedding.yaml | 2 + .../output/08-keyword-embedding-domain.yaml | 2 + .../testdata/output/09-keyword-plugin.yaml | 2 + .../testdata/output/10-embedding-plugin.yaml | 2 + .../k8s/testdata/output/11-domain-plugin.yaml | 2 + .../output/12-keyword-embedding-plugin.yaml | 2 + .../output/13-keyword-domain-plugin.yaml | 2 + .../output/14-domain-embedding-plugin.yaml | 2 + .../15-keyword-embedding-domain-plugin.yaml | 2 + ...16-keyword-embedding-domain-no-plugin.yaml | 2 + .../pkg/routerreplay/recorder.go | 39 ++- .../pkg/routerreplay/recorder_test.go | 48 +-- .../pkg/routerreplay/store/postgres.go | 58 +++- .../routerreplay/store/postgres_record_row.go | 15 +- .../pkg/routerreplay/store/store.go | 6 + src/semantic-router/pkg/selection/factory.go | 21 ++ .../pkg/selection/lookuptable/builder.go | 24 +- src/semantic-router/pkg/selection/selector.go | 28 ++ .../pkg/selection/session_aware.go | 315 ++++++++++++++++++ .../pkg/selection/tier_declarations.go | 10 + src/vllm-sr/cli/algorithms.py | 1 + src/vllm-sr/cli/config_contract.py | 1 + src/vllm-sr/cli/models.py | 1 + src/vllm-sr/cli/validator.py | 8 + tools/agent/repo-manifest.yaml | 4 + website/docs/tutorials/algorithm/overview.md | 1 + .../algorithm/selection/session-aware.md | 85 +++++ .../tutorials/signal/heuristic/session.md | 80 +++++ website/docs/tutorials/signal/overview.md | 1 + 95 files changed, 1853 insertions(+), 510 deletions(-) create mode 100644 config/algorithm/selection/session-aware.yaml create mode 100644 config/signal/session/runtime-facts.yaml delete mode 100644 deploy/recipes/session-state.dsl delete mode 100644 deploy/recipes/session-state.yaml create mode 100644 docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md create mode 100644 src/semantic-router/pkg/config/session_selection_config.go create mode 100644 src/semantic-router/pkg/config/session_signal_config.go delete mode 100644 src/semantic-router/pkg/config/session_state_config.go create mode 100644 src/semantic-router/pkg/config/validator_session.go create mode 100644 src/semantic-router/pkg/extproc/req_filter_classification_session.go create mode 100644 src/semantic-router/pkg/extproc/router_selection_context.go create mode 100644 src/semantic-router/pkg/selection/session_aware.go create mode 100644 website/docs/tutorials/algorithm/selection/session-aware.md create mode 100644 website/docs/tutorials/signal/heuristic/session.md diff --git a/config/README.md b/config/README.md index cb3451fb05..0cbe037055 100644 --- a/config/README.md +++ b/config/README.md @@ -41,7 +41,7 @@ Decision fragments may reference `modelRefs[].lora_name`, but those adapter name `config/algorithm/` is organized by routing policy: - `looper/`: multi-model execution policies such as `confidence`, `ratings`, and `remom` -- `selection/`: candidate-selection policies such as `elo`, `router_dc`, `automix`, and `latency_aware` +- `selection/`: candidate-selection policies such as `elo`, `router_dc`, `automix`, `session_aware`, and `latency_aware` Each supported algorithm now has its own tutorial page under `website/docs/tutorials/algorithm/`. diff --git a/config/algorithm/selection/session-aware.yaml b/config/algorithm/selection/session-aware.yaml new file mode 100644 index 0000000000..083e761501 --- /dev/null +++ b/config/algorithm/selection/session-aware.yaml @@ -0,0 +1,9 @@ +algorithm: + type: session_aware + session_aware: + fallback_method: hybrid + min_turns_before_switch: 2 + stay_bias: 0.3 + quality_gap_multiplier: 1.15 + handoff_penalty_weight: 0.9 + remaining_turn_weight: 0.45 diff --git a/config/config.yaml b/config/config.yaml index 736d073c7d..7f487ddd4e 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -467,22 +467,6 @@ routing: - name: support_escalated gte: 0.45 - session_states: - - name: session_routing - fields: - - name: turn_number - type: int - - name: current_model - type: string - - name: cumulative_cost_usd - type: float - - name: retry_count_ema - type: float - - name: quality_score_ema - type: float - - name: kv_cache_warm - type: float - decisions: - name: static_business_route description: Static fallback for standard business traffic. @@ -787,6 +771,32 @@ routing: cost_weight: 0.1 quality_gap_threshold: 0.08 normalize_scores: true + - name: session_continuation_route + description: Session-aware route that prefers staying on the current technical model mid-conversation. + priority: 136 + rules: + operator: AND + conditions: + - type: session + name: session_present + - type: domain + name: "computer science" + modelRefs: + - model: qwen3-8b + use_reasoning: false + - model: qwen3-32b + lora_name: computer-science-expert + use_reasoning: true + algorithm: + type: session_aware + session_aware: + fallback_method: hybrid + min_turns_before_switch: 2 + stay_bias: 0.3 + quality_gap_multiplier: 1.15 + handoff_penalty_weight: 0.9 + remaining_turn_weight: 0.45 + - name: spanish_rl_route description: RL-driven route for Spanish-language traffic with Milvus RAG. priority: 135 diff --git a/config/signal/session/runtime-facts.yaml b/config/signal/session/runtime-facts.yaml new file mode 100644 index 0000000000..31b25a6e56 --- /dev/null +++ b/config/signal/session/runtime-facts.yaml @@ -0,0 +1,22 @@ +routing: + signals: + session: + - name: session_present + description: Requests that belong to an existing multi-turn conversation. + fact: session_present + predicate: + gte: 1 + - name: warm_cache_continuation + description: Prefer staying on the warmed model when the same conversation continues. + fact: cache_warmth + previous_model: qwen3-8b + predicate: + gte: 0.6 + - name: expensive_handoff + description: Detect costly mid-session upgrades into the premium coding model. + fact: handoff_penalty + intent_or_domain: computer science + previous_model: qwen3-8b + candidate_model: qwen3-32b + predicate: + gte: 0.15 diff --git a/dashboard/backend/handlers/topology.go b/dashboard/backend/handlers/topology.go index 3f4a71d72b..dc4689c40f 100644 --- a/dashboard/backend/handlers/topology.go +++ b/dashboard/backend/handlers/topology.go @@ -137,6 +137,7 @@ type RouterMatchedSignals struct { Structure []string `json:"structure,omitempty"` Complexity []string `json:"complexity,omitempty"` Modality []string `json:"modality,omitempty"` + Session []string `json:"session,omitempty"` Authz []string `json:"authz,omitempty"` Jailbreak []string `json:"jailbreak,omitempty"` PII []string `json:"pii,omitempty"` diff --git a/dashboard/backend/handlers/topology_response.go b/dashboard/backend/handlers/topology_response.go index d65dca0d18..3585d4d5d8 100644 --- a/dashboard/backend/handlers/topology_response.go +++ b/dashboard/backend/handlers/topology_response.go @@ -71,6 +71,7 @@ func topologySignalMappings(matchedSignals *RouterMatchedSignals) []topologySign {signalType: "structure", names: matchedSignals.Structure, defaultConfidence: 1.0, reason: "Structure rule matched", addPath: true}, {signalType: "complexity", names: matchedSignals.Complexity, defaultConfidence: 0.9, reason: "Complexity level matched", addPath: true}, {signalType: "modality", names: matchedSignals.Modality, defaultConfidence: 1.0, reason: "Modality signal matched", addPath: true}, + {signalType: "session", names: matchedSignals.Session, defaultConfidence: 1.0, reason: "Session signal matched", addPath: true}, {signalType: "authz", names: matchedSignals.Authz, defaultConfidence: 1.0, reason: "Authorization signal matched", addPath: true}, {signalType: "jailbreak", names: matchedSignals.Jailbreak, defaultConfidence: 1.0, reason: "Jailbreak signal matched", addPath: true}, {signalType: "pii", names: matchedSignals.PII, defaultConfidence: 1.0, reason: "PII signal matched", addPath: true}, diff --git a/dashboard/frontend/src/pages/ConfigPage.tsx b/dashboard/frontend/src/pages/ConfigPage.tsx index 056791d793..d739280437 100644 --- a/dashboard/frontend/src/pages/ConfigPage.tsx +++ b/dashboard/frontend/src/pages/ConfigPage.tsx @@ -272,6 +272,9 @@ const ConfigPage: React.FC = ({ activeSection = 'global-config' case 'Modality': cfg.signals.modality = (cfg.signals.modality || []).filter(s => s.name !== targetName) break + case 'Session': + cfg.signals.session = (cfg.signals.session || []).filter(s => s.name !== targetName) + break case 'Authz': cfg.signals.role_bindings = (cfg.signals.role_bindings || []).filter(s => s.name !== targetName) break diff --git a/dashboard/frontend/src/pages/ConfigPageDecisionsSection.tsx b/dashboard/frontend/src/pages/ConfigPageDecisionsSection.tsx index bc1a41789f..3776ecf850 100644 --- a/dashboard/frontend/src/pages/ConfigPageDecisionsSection.tsx +++ b/dashboard/frontend/src/pages/ConfigPageDecisionsSection.tsx @@ -259,8 +259,12 @@ export default function ConfigPageDecisionsSection({ ]) case 'modality': return config?.signals?.modality?.map((m) => m.name) || [] + case 'session': + return config?.signals?.session?.map((rule) => rule.name) || [] case 'authz': return config?.signals?.role_bindings?.map((binding) => binding.name) || [] + case 'kb': + return config?.signals?.kb?.map((rule) => rule.name) || [] case 'jailbreak': return config?.signals?.jailbreak?.map((rule) => rule.name) || [] case 'pii': diff --git a/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx b/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx index 4013044d4e..2dc607dbbe 100644 --- a/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx +++ b/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx @@ -18,6 +18,7 @@ import type { KeywordSignal, LanguageSignal, ModalitySignal, + SessionSignal, PIISignal, PreferenceSignal, ReaskSignal, @@ -225,6 +226,18 @@ export default function ConfigPageSignalsSection({ }) }) + effectiveSignals?.session?.forEach(session => { + const scope = [session.fact, session.previous_model, session.candidate_model] + .filter(Boolean) + .join(' • ') + allSignals.push({ + name: session.name, + type: 'Session', + summary: scope || session.description || 'Session-derived runtime fact', + rawData: session, + }) + }) + effectiveSignals?.role_bindings?.forEach(binding => { const subjectCount = binding.subjects?.length || 0 allSignals.push({ @@ -519,6 +532,22 @@ export default function ConfigPageSignalsSection({ { label: 'Description', value: signal.rawData.description || 'N/A', fullWidth: true }, ] }) + } else if (signal.type === 'Session') { + sections.push({ + title: 'Session Signal', + fields: [ + { label: 'Fact', value: signal.rawData.fact || 'N/A' }, + { label: 'Intent / Domain', value: signal.rawData.intent_or_domain || 'Any' }, + { label: 'Previous Model', value: signal.rawData.previous_model || 'Any' }, + { label: 'Candidate Model', value: signal.rawData.candidate_model || 'Any' }, + { + label: 'Predicate', + value: signal.rawData.predicate ? JSON.stringify(signal.rawData.predicate, null, 2) : 'Always match', + fullWidth: true, + }, + { label: 'Description', value: signal.rawData.description || 'N/A', fullWidth: true }, + ] + }) } else if (signal.type === 'Authz') { sections.push({ title: 'Role Binding', @@ -617,6 +646,11 @@ export default function ConfigPageSignalsSection({ easy_candidates: '', composer_operator: 'AND', composer_conditions: '', + session_fact: '', + session_predicate: JSON.stringify({ gte: 1 }, null, 2), + session_intent_or_domain: '', + session_previous_model: '', + session_candidate_model: '', jailbreak_threshold: 0.65, jailbreak_method: 'classifier', include_history: false, @@ -658,6 +692,11 @@ export default function ConfigPageSignalsSection({ easy_candidates: (signal.rawData.easy?.candidates || []).join('\n'), composer_operator: signal.rawData.composer?.operator || 'AND', composer_conditions: signal.rawData.composer?.conditions?.map((c: { type: string; name: string }) => `${c.type}:${c.name}`).join('\n') || '', + session_fact: signal.type === 'Session' ? signal.rawData.fact || '' : '', + session_predicate: signal.type === 'Session' && signal.rawData.predicate ? JSON.stringify(signal.rawData.predicate, null, 2) : defaultForm.session_predicate, + session_intent_or_domain: signal.type === 'Session' ? signal.rawData.intent_or_domain || '' : '', + session_previous_model: signal.type === 'Session' ? signal.rawData.previous_model || '' : '', + session_candidate_model: signal.type === 'Session' ? signal.rawData.candidate_model || '' : '', jailbreak_threshold: signal.rawData.threshold ?? 0.65, jailbreak_method: signal.rawData.method || 'classifier', include_history: !!signal.rawData.include_history, @@ -681,7 +720,7 @@ export default function ConfigPageSignalsSection({ name: 'type', label: 'Type', type: 'select', - options: ['Keywords', 'Embeddings', 'Domain', 'Preference', 'Fact Check', 'User Feedback', 'Reask', 'Language', 'Context', 'Structure', 'Complexity', 'Modality', 'Authz', 'Jailbreak', 'PII', 'KB'], + options: ['Keywords', 'Embeddings', 'Domain', 'Preference', 'Fact Check', 'User Feedback', 'Reask', 'Language', 'Context', 'Structure', 'Complexity', 'Modality', 'Session', 'Authz', 'Jailbreak', 'PII', 'KB'], required: true, description: 'Fields are validated based on the selected type.' }, @@ -861,6 +900,43 @@ export default function ConfigPageSignalsSection({ description: 'Phrases representing easy/simple queries', shouldHide: conditionallyHideFieldExceptType('Complexity') }, + { + name: 'session_fact', + label: 'Fact (session only)', + type: 'text', + placeholder: 'session_present', + description: 'Runtime-derived session fact name consumed by the router.', + shouldHide: conditionallyHideFieldExceptType('Session') + }, + { + name: 'session_predicate', + label: 'Predicate (session only)', + type: 'textarea', + placeholder: '{\n "gte": 1\n}', + description: 'Optional numeric predicate JSON applied to the runtime fact value.', + shouldHide: conditionallyHideFieldExceptType('Session') + }, + { + name: 'session_intent_or_domain', + label: 'Intent / Domain (session only)', + type: 'text', + placeholder: 'computer science', + shouldHide: conditionallyHideFieldExceptType('Session') + }, + { + name: 'session_previous_model', + label: 'Previous Model (session only)', + type: 'text', + placeholder: 'qwen3-8b', + shouldHide: conditionallyHideFieldExceptType('Session') + }, + { + name: 'session_candidate_model', + label: 'Candidate Model (session only)', + type: 'text', + placeholder: 'qwen3-32b', + shouldHide: conditionallyHideFieldExceptType('Session') + }, { name: 'role', label: 'Role (authz only)', @@ -1231,6 +1307,27 @@ export default function ConfigPageSignalsSection({ ] break } + case 'Session': { + const fact = (formData.session_fact || '').trim() + if (!fact) { + throw new Error('Fact is required for session signals.') + } + const predicateText = (formData.session_predicate || '').trim() + const predicate = predicateText ? JSON.parse(predicateText) : undefined + newConfig.signals.session = [ + ...(newConfig.signals.session || []), + { + name, + description: formData.description || undefined, + fact, + predicate, + intent_or_domain: (formData.session_intent_or_domain || '').trim() || undefined, + previous_model: (formData.session_previous_model || '').trim() || undefined, + candidate_model: (formData.session_candidate_model || '').trim() || undefined, + } + ] + break + } case 'Authz': { const role = (formData.role || '').trim() if (!role) { diff --git a/dashboard/frontend/src/pages/DashboardPage.tsx b/dashboard/frontend/src/pages/DashboardPage.tsx index 03f60c426e..59d85fbfef 100644 --- a/dashboard/frontend/src/pages/DashboardPage.tsx +++ b/dashboard/frontend/src/pages/DashboardPage.tsx @@ -132,9 +132,11 @@ const SIGNAL_COLORS: Record = { context: '#D7BA7D', complexity: '#569CD6', modality: '#D4D4D4', + session: '#7AA2F7', authz: '#F48771', jailbreak: '#F48771', pii: '#FF6B6B', + kb: '#9A7BFF', } const MiniFlowDiagram: React.FC = React.memo(({ signals, decisions, models, plugins }) => { diff --git a/dashboard/frontend/src/pages/configPageRouterDefaultsSupport.ts b/dashboard/frontend/src/pages/configPageRouterDefaultsSupport.ts index 9168587b8f..43a37a3f71 100644 --- a/dashboard/frontend/src/pages/configPageRouterDefaultsSupport.ts +++ b/dashboard/frontend/src/pages/configPageRouterDefaultsSupport.ts @@ -700,7 +700,7 @@ function fieldsForKey(key: RouterSystemKey): FieldConfig[] { case 'model_selection': return [ { name: 'enabled', label: 'Enable Model Selection', type: 'boolean' }, - { name: 'default_algorithm', label: 'Method', type: 'select', options: ['knn', 'kmeans', 'svm', 'elo', 'router_dc', 'automix', 'hybrid'], required: true }, + { name: 'default_algorithm', label: 'Method', type: 'select', options: ['knn', 'kmeans', 'svm', 'elo', 'router_dc', 'automix', 'hybrid', 'session_aware'], required: true }, { name: 'models_path', label: 'ML Models Path', type: 'text', placeholder: 'models/model_selection' }, { name: 'knn', label: 'KNN Config (JSON)', type: 'json' }, { name: 'kmeans', label: 'KMeans Config (JSON)', type: 'json' }, @@ -709,6 +709,7 @@ function fieldsForKey(key: RouterSystemKey): FieldConfig[] { { name: 'router_dc', label: 'RouterDC Config (JSON)', type: 'json' }, { name: 'automix', label: 'AutoMix Config (JSON)', type: 'json' }, { name: 'hybrid', label: 'Hybrid Config (JSON)', type: 'json' }, + { name: 'session_aware', label: 'Session-Aware Config (JSON)', type: 'json' }, ] case 'api': return [{ name: 'batch_classification', label: 'Batch Classification (JSON)', type: 'json', placeholder: '{"metrics":{"enabled":true}}' }] @@ -775,6 +776,7 @@ function editDataForKey(key: RouterSystemKey, data: unknown): EditFormData { router_dc: asObject(selection?.router_dc) || {}, automix: asObject(selection?.automix) || {}, hybrid: asObject(selection?.hybrid) || {}, + session_aware: asObject(selection?.session_aware) || {}, } } const objectData = asObject(data) @@ -829,6 +831,7 @@ function saveForKey(key: RouterSystemKey, data: EditFormData): Partial>(({ data }) => { } return parts.length > 0 ? parts.join(', ') : null } + if (algorithm.type === 'session_aware' && algorithm.session_aware) { + const parts: string[] = [] + if (algorithm.session_aware.fallback_method) { + parts.push(`fallback ${algorithm.session_aware.fallback_method}`) + } + if (algorithm.session_aware.min_turns_before_switch !== undefined) { + parts.push(`min turns ${algorithm.session_aware.min_turns_before_switch}`) + } + if (algorithm.session_aware.stay_bias !== undefined) { + parts.push(`stay bias ${algorithm.session_aware.stay_bias}`) + } + return parts.length > 0 ? parts.join(', ') : null + } return null } diff --git a/dashboard/frontend/src/pages/topology/constants.ts b/dashboard/frontend/src/pages/topology/constants.ts index d38b5fed1d..4245d0d6e6 100644 --- a/dashboard/frontend/src/pages/topology/constants.ts +++ b/dashboard/frontend/src/pages/topology/constants.ts @@ -16,6 +16,7 @@ export const SIGNAL_ICONS: Record = { structure: 'STR', complexity: 'CPX', modality: 'MOD', + session: 'SES', authz: 'AUTH', jailbreak: 'JB', pii: 'PII', @@ -124,6 +125,7 @@ export const ALGORITHM_COLORS: Record +export interface SessionSignalConfig { + fact?: string + predicate?: NumericPredicateConfig + intent_or_domain?: string + previous_model?: string + candidate_model?: string +} + export interface AuthzSignalConfig { role?: string } @@ -208,6 +216,15 @@ export interface LatencyAwareAlgorithmConfig { description?: string } +export interface SessionAwareAlgorithmConfig { + fallback_method?: string + min_turns_before_switch?: number + stay_bias?: number + quality_gap_multiplier?: number + handoff_penalty_weight?: number + remaining_turn_weight?: number +} + export interface AutoMixConfig { // POMDP cascade config [key: string]: unknown @@ -654,6 +671,14 @@ export interface ConfigData { ttft_percentile?: number description?: string } + session_aware?: { + fallback_method?: string + min_turns_before_switch?: number + stay_bias?: number + quality_gap_multiplier?: number + handoff_penalty_weight?: number + remaining_turn_weight?: number + } } modelRefs?: Array<{ model: string diff --git a/dashboard/frontend/src/pages/topology/utils/topologyParser.ts b/dashboard/frontend/src/pages/topology/utils/topologyParser.ts index fd0920903d..f33c75d467 100644 --- a/dashboard/frontend/src/pages/topology/utils/topologyParser.ts +++ b/dashboard/frontend/src/pages/topology/utils/topologyParser.ts @@ -421,7 +421,24 @@ function extractSignals(config: ConfigData): SignalConfig[] { }) }) - // 12. Authz / RBAC Role Bindings + // 12. Session Rules + routingSignals?.session?.forEach(rule => { + addSignal({ + type: 'session', + name: rule.name, + description: rule.description, + latency: SIGNAL_LATENCY.session, + config: { + fact: rule.fact, + predicate: rule.predicate, + intent_or_domain: rule.intent_or_domain, + previous_model: rule.previous_model, + candidate_model: rule.candidate_model, + }, + }) + }) + + // 13. Authz / RBAC Role Bindings // From role_bindings (Go/Router format) config.role_bindings?.forEach(rule => { addSignal({ @@ -578,6 +595,7 @@ function extractDecisions(config: ConfigData): DecisionConfig[] { confidence: decision.algorithm.confidence, concurrent: decision.algorithm.concurrent, latency_aware: decision.algorithm.latency_aware, + session_aware: decision.algorithm.session_aware, } : undefined diff --git a/dashboard/frontend/src/types/config.ts b/dashboard/frontend/src/types/config.ts index f05fd6697e..2354ebc080 100644 --- a/dashboard/frontend/src/types/config.ts +++ b/dashboard/frontend/src/types/config.ts @@ -168,6 +168,16 @@ export interface ModalitySignal { description?: string } +export interface SessionSignal { + name: string + description?: string + fact: string + predicate?: NumericPredicate + intent_or_domain?: string + previous_model?: string + candidate_model?: string +} + export interface Subject { kind: 'User' | 'Group' name: string @@ -221,7 +231,7 @@ export interface Signals { // ============================================================================= -export type DecisionConditionType = 'keyword' | 'domain' | 'preference' | 'user_feedback' | 'reask' | 'embedding' | 'fact_check' | 'language' | 'context' | 'structure' | 'complexity' | 'modality' | 'authz' | 'jailbreak' | 'pii' | 'projection' +export type DecisionConditionType = 'keyword' | 'domain' | 'preference' | 'user_feedback' | 'reask' | 'embedding' | 'fact_check' | 'language' | 'context' | 'structure' | 'complexity' | 'modality' | 'session' | 'authz' | 'jailbreak' | 'pii' | 'projection' | 'kb' export interface DecisionCondition { type: DecisionConditionType name: string diff --git a/deploy/recipes/session-state.dsl b/deploy/recipes/session-state.dsl deleted file mode 100644 index ddedd123a7..0000000000 --- a/deploy/recipes/session-state.dsl +++ /dev/null @@ -1,23 +0,0 @@ -# ============================================================================= -# SESSION STATE SCHEMA RECIPE -# ============================================================================= - -SESSION_STATE session_routing { - # Conversation turn counter — incremented by the router each turn. - turn_number: int - - # Name of the model that served the most recent assistant turn. - current_model: string - - # Accumulated spend for this session in USD, updated after each turn. - cumulative_cost_usd: float - - # Exponential moving average of per-turn cost; tracks spend momentum. - retry_count_ema: float - - # Exponential moving average of quality scores from the eval pipeline. - quality_score_ema: float - - # Fraction of the KV cache that was warm on the most recent turn (0–1). - kv_cache_warm: float -} diff --git a/deploy/recipes/session-state.yaml b/deploy/recipes/session-state.yaml deleted file mode 100644 index 55def6816b..0000000000 --- a/deploy/recipes/session-state.yaml +++ /dev/null @@ -1,16 +0,0 @@ -routing: - session_states: - - name: session_routing - fields: - - name: turn_number - type: int - - name: current_model - type: string - - name: cumulative_cost_usd - type: float - - name: retry_count_ema - type: float - - name: quality_score_ema - type: float - - name: kv_cache_warm - type: float diff --git a/docs/agent/plans/README.md b/docs/agent/plans/README.md index 1e698677f2..192ee0fd1b 100644 --- a/docs/agent/plans/README.md +++ b/docs/agent/plans/README.md @@ -87,6 +87,7 @@ Keep the numeric index unique within `docs/agent/plans/`. - [pl-0023-dashboard-dsl-natural-language-mode-loop.md](pl-0023-dashboard-dsl-natural-language-mode-loop.md) - [pl-0024-balance-recipe-simplification-loop.md](pl-0024-balance-recipe-simplification-loop.md) - [pl-0025-prototype-aware-embedding-backed-signal-scoring-loop.md](pl-0025-prototype-aware-embedding-backed-signal-scoring-loop.md) +- [pl-0030-session-aware-routing-convergence-loop.md](pl-0030-session-aware-routing-convergence-loop.md) ## Completed Execution Plans diff --git a/docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md b/docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md new file mode 100644 index 0000000000..015299a103 --- /dev/null +++ b/docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md @@ -0,0 +1,63 @@ +# Session-Aware Routing Convergence Loop + +## Goal + +- Retire the legacy `routing.session_states` / `SESSION_STATE` public surface instead of extending it further. +- Converge the repository on one production-oriented session-aware routing contract built around runtime-derived session facts, session-aware signals, and explicit selection policy wiring. +- Close the workstream only after config, runtime, replay/session identity, docs, and targeted validation all agree on the same steady-state behavior. + +## Scope + +- `docs/agent/plans/**` +- `src/semantic-router/pkg/config/**` +- `src/semantic-router/pkg/dsl/**` +- `src/semantic-router/pkg/extproc/**` +- `src/semantic-router/pkg/selection/**` +- `src/semantic-router/pkg/selection/lookuptable/**` +- `src/semantic-router/pkg/routerreplay/**` +- maintained config / recipe assets that exercise routing contracts +- targeted docs and validation paths for the touched surfaces +- nearest local rules for `pkg/config` and `pkg/extproc` + +## Exit Criteria + +- The repository no longer exposes `routing.session_states` or `SESSION_STATE` as supported steady-state config or DSL surface. +- Runtime-derived session facts are the only durable source of session-aware inputs consumed by routing logic. +- The config contract exposes one explicit session-aware control surface for decision-time routing, instead of spreading behavior across legacy schema leftovers and hidden selector heuristics. +- Session-aware routing can evaluate stay-versus-switch behavior using replay-backed lookup-table signals plus real session identity, not pseudo-session heuristics alone. +- Targeted docs, maintained assets, and changed-set validation reflect the steady-state contract and no longer teach the removed legacy surface. + +## Task List + +- [x] `SAR001` Create and index the durable execution plan for the session-aware routing workstream. +- [x] `SAR002` Remove the legacy `routing.session_states` / `SESSION_STATE` public surface from config, canonical export/import, DSL, maintained assets, and tests. +- [ ] `SAR003` Define the steady-state session-aware config contract around runtime-derived session facts, decision-time control knobs, and supported algorithm surfaces. +- [ ] `SAR004` Add the missing session-aware signal family and validator/catalog wiring so decisions can reference session / lookup signals explicitly. +- [ ] `SAR005` Implement the session-aware selector path that evaluates stay-versus-switch behavior using turn index, previous model, lookup-table priors, and fallback defaults. +- [ ] `SAR006` Replace pseudo-session replay grouping with a real session identity contract across replay ingestion, lookup-table derivation, and response-side telemetry. +- [ ] `SAR007` Add the production gates still missing for retention, shadow / audit output, and operator-visible debugging of session-aware decisions. +- [ ] `SAR008` Update maintained docs, examples, and targeted validation so the new session-aware contract is the only documented public path. +- [ ] `SAR009` Run the validation ladder for the changed surfaces, record results here, and add indexed debt only for gaps that remain after the loop closes. + +## Current Loop + +- Loop status: opened on 2026-04-15. +- Completed in this loop: + - confirmed the repository-native harness and plan requirements before editing + - removed `routing.session_states` from the config schema, canonical routing surface, DSL grammar/compiler/decompiler/validator chain, maintained recipe assets, and affected tests + - created this execution plan so the remaining multi-loop session-aware routing work is resumable from the repository alone +- Next loop focus: + - execute `SAR003` by defining the steady-state session-aware public contract without reintroducing legacy schema compatibility into runtime parsing + +## Decision Log + +- 2026-04-15: remove `routing.session_states` directly instead of keeping a compatibility bridge in the steady-state runtime contract. +- 2026-04-15: treat runtime-derived session facts plus explicit session-aware routing controls as the only forward path for multi-turn routing. +- 2026-04-15: keep lookup tables as reusable substrate, but do not treat the current pseudo-session replay heuristic as production-complete session identity. +- 2026-04-15: use one execution plan because the remaining work spans config, DSL, extproc, selection, replay, maintained assets, and validation loops. + +## Follow-up Debt / ADR Links + +- `issue #1753` stable multi-turn session-aware routing goal +- `docs/agent/lookup-tables.md` +- Add an ADR or debt entry only if the steady-state session-aware contract still diverges after this loop series completes. diff --git a/src/semantic-router/pkg/classification/classifier_signal_eval.go b/src/semantic-router/pkg/classification/classifier_signal_eval.go index bc0f6e20f6..1d787a7c2b 100644 --- a/src/semantic-router/pkg/classification/classifier_signal_eval.go +++ b/src/semantic-router/pkg/classification/classifier_signal_eval.go @@ -86,6 +86,7 @@ type SignalResults struct { MatchedStructureRules []string // Matched structure rule names (e.g. "many_questions") MatchedComplexityRules []string // Matched complexity rules with difficulty level (e.g. "code_complexity:hard") MatchedModalityRules []string // Matched modality: "AR", "DIFFUSION", or "BOTH" + MatchedSessionRules []string // Runtime-derived session signals injected after classifier evaluation MatchedAuthzRules []string // Matched authz role names for user-level RBAC routing MatchedJailbreakRules []string // Matched jailbreak rule names (confidence >= threshold) MatchedPIIRules []string // Matched PII rule names (denied PII types detected) @@ -299,11 +300,11 @@ func (c *Classifier) evaluateDecisionInternal(signals *SignalResults, trace bool return nil, nil, fmt.Errorf("no decisions configured") } - logging.Debugf("Signal evaluation results: keyword=%v, embedding=%v, domain=%v, fact_check=%v, user_feedback=%v, reask=%v, preference=%v, language=%v, context=%v, structure=%v, complexity=%v, modality=%v, authz=%v, jailbreak=%v, pii=%v, kb=%v", + logging.Debugf("Signal evaluation results: keyword=%v, embedding=%v, domain=%v, fact_check=%v, user_feedback=%v, reask=%v, preference=%v, language=%v, context=%v, structure=%v, complexity=%v, modality=%v, session=%v, authz=%v, jailbreak=%v, pii=%v, kb=%v", signals.MatchedKeywordRules, signals.MatchedEmbeddingRules, signals.MatchedDomainRules, signals.MatchedFactCheckRules, signals.MatchedUserFeedbackRules, signals.MatchedReaskRules, signals.MatchedPreferenceRules, signals.MatchedLanguageRules, signals.MatchedContextRules, signals.MatchedStructureRules, - signals.MatchedComplexityRules, signals.MatchedModalityRules, signals.MatchedAuthzRules, + signals.MatchedComplexityRules, signals.MatchedModalityRules, signals.MatchedSessionRules, signals.MatchedAuthzRules, signals.MatchedJailbreakRules, signals.MatchedPIIRules, signals.MatchedKBRules) engine := decision.NewDecisionEngine( @@ -327,6 +328,7 @@ func (c *Classifier) evaluateDecisionInternal(signals *SignalResults, trace bool StructureRules: signals.MatchedStructureRules, ComplexityRules: signals.MatchedComplexityRules, ModalityRules: signals.MatchedModalityRules, + SessionRules: signals.MatchedSessionRules, SignalConfidences: signals.SignalConfidences, AuthzRules: signals.MatchedAuthzRules, JailbreakRules: signals.MatchedJailbreakRules, diff --git a/src/semantic-router/pkg/config/canonical_config.go b/src/semantic-router/pkg/config/canonical_config.go index 3b192ccf8e..d94613d285 100644 --- a/src/semantic-router/pkg/config/canonical_config.go +++ b/src/semantic-router/pkg/config/canonical_config.go @@ -20,11 +20,10 @@ type CanonicalConfig struct { // CanonicalRouting contains the DSL-owned routing surface. type CanonicalRouting struct { - ModelCards []RoutingModel `yaml:"modelCards,omitempty"` - Signals CanonicalSignals `yaml:"signals,omitempty"` - Projections CanonicalProjections `yaml:"projections,omitempty"` - Decisions []Decision `yaml:"decisions,omitempty"` - SessionStates []SessionStateConfig `yaml:"session_states,omitempty"` + ModelCards []RoutingModel `yaml:"modelCards,omitempty"` + Signals CanonicalSignals `yaml:"signals,omitempty"` + Projections CanonicalProjections `yaml:"projections,omitempty"` + Decisions []Decision `yaml:"decisions,omitempty"` } // CanonicalSignals groups routing signals under routing.signals. @@ -41,6 +40,7 @@ type CanonicalSignals struct { Structure []StructureRule `yaml:"structure,omitempty"` Complexity []ComplexityRule `yaml:"complexity,omitempty"` Modality []ModalityRule `yaml:"modality,omitempty"` + Session []SessionRule `yaml:"session,omitempty"` RoleBindings []RoleBinding `yaml:"role_bindings,omitempty"` Jailbreak []JailbreakRule `yaml:"jailbreak,omitempty"` PII []PIIRule `yaml:"pii,omitempty"` @@ -106,7 +106,6 @@ func applyCanonicalRoutingState(cfg *RouterConfig, canonical *CanonicalConfig) { ensureModelRefDefaults(cfg.Decisions) cfg.Signals = normalizeSignals(canonical.Routing.Signals, cfg.Decisions) cfg.Projections = normalizeProjections(canonical.Routing.Projections) - cfg.SessionStates = append([]SessionStateConfig(nil), canonical.Routing.SessionStates...) cfg.ModelConfig = make(map[string]ModelParams) for _, model := range canonicalRoutingModels(canonical.Routing) { @@ -247,6 +246,7 @@ func normalizeSignals(signals CanonicalSignals, decisions []Decision) Signals { StructureRules: append([]StructureRule(nil), signals.Structure...), ComplexityRules: append([]ComplexityRule(nil), signals.Complexity...), ModalityRules: append([]ModalityRule(nil), signals.Modality...), + SessionRules: append([]SessionRule(nil), signals.Session...), RoleBindings: append([]RoleBinding(nil), signals.RoleBindings...), JailbreakRules: append([]JailbreakRule(nil), signals.Jailbreak...), PIIRules: append([]PIIRule(nil), signals.PII...), diff --git a/src/semantic-router/pkg/config/canonical_export.go b/src/semantic-router/pkg/config/canonical_export.go index 83960768a7..0213b4b5db 100644 --- a/src/semantic-router/pkg/config/canonical_export.go +++ b/src/semantic-router/pkg/config/canonical_export.go @@ -46,11 +46,10 @@ func CanonicalRoutingFromRouterConfig(cfg *RouterConfig) CanonicalRouting { } return CanonicalRouting{ - ModelCards: routingModelsFromRouterConfig(cfg), - Signals: canonicalSignalsFromRouterConfig(cfg), - Projections: canonicalProjectionsFromRouterConfig(cfg), - Decisions: copyDecisions(cfg.Decisions), - SessionStates: append([]SessionStateConfig(nil), cfg.SessionStates...), + ModelCards: routingModelsFromRouterConfig(cfg), + Signals: canonicalSignalsFromRouterConfig(cfg), + Projections: canonicalProjectionsFromRouterConfig(cfg), + Decisions: copyDecisions(cfg.Decisions), } } @@ -68,6 +67,7 @@ func canonicalSignalsFromRouterConfig(cfg *RouterConfig) CanonicalSignals { Structure: append([]StructureRule(nil), cfg.StructureRules...), Complexity: append([]ComplexityRule(nil), cfg.ComplexityRules...), Modality: append([]ModalityRule(nil), cfg.ModalityRules...), + Session: append([]SessionRule(nil), cfg.SessionRules...), RoleBindings: append([]RoleBinding(nil), cfg.RoleBindings...), Jailbreak: append([]JailbreakRule(nil), cfg.JailbreakRules...), PII: append([]PIIRule(nil), cfg.PIIRules...), diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 2df327600c..05795a5038 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -34,6 +34,7 @@ const ( SignalTypeStructure = "structure" SignalTypeComplexity = "complexity" SignalTypeModality = "modality" + SignalTypeSession = "session" SignalTypeAuthz = "authz" SignalTypeJailbreak = "jailbreak" SignalTypePII = "pii" @@ -183,11 +184,11 @@ type IntelligentRouting struct { Signals `yaml:",inline"` Projections Projections `yaml:"projections,omitempty"` Decisions []Decision `yaml:"decisions,omitempty"` - Strategy string `yaml:"strategy,omitempty"` - ModelSelection ModelSelectionConfig `yaml:"model_selection,omitempty"` - ReasoningConfig `yaml:",inline"` - SessionStates []SessionStateConfig `yaml:"session_states,omitempty"` -} + Strategy string `yaml:"strategy,omitempty"` + ModelSelection ModelSelectionConfig `yaml:"model_selection,omitempty"` + ReasoningConfig `yaml:",inline"` + } + // BackendModels captures configured backend endpoints and model metadata. type BackendModels struct { diff --git a/src/semantic-router/pkg/config/decision_config.go b/src/semantic-router/pkg/config/decision_config.go index b033dfbb44..e5e6484f59 100644 --- a/src/semantic-router/pkg/config/decision_config.go +++ b/src/semantic-router/pkg/config/decision_config.go @@ -22,6 +22,7 @@ type AlgorithmConfig struct { RouterDC *RouterDCSelectionConfig `yaml:"router_dc,omitempty"` AutoMix *AutoMixSelectionConfig `yaml:"automix,omitempty"` Hybrid *HybridSelectionConfig `yaml:"hybrid,omitempty"` + SessionAware *SessionAwareSelectionConfig `yaml:"session_aware,omitempty"` RLDriven *RLDrivenSelectionConfig `yaml:"rl_driven,omitempty"` GMTRouter *GMTRouterSelectionConfig `yaml:"gmtrouter,omitempty"` LatencyAware *LatencyAwareAlgorithmConfig `yaml:"latency_aware,omitempty"` diff --git a/src/semantic-router/pkg/config/fragment_catalog_test.go b/src/semantic-router/pkg/config/fragment_catalog_test.go index 4b6b2abe69..9f3e444ff4 100644 --- a/src/semantic-router/pkg/config/fragment_catalog_test.go +++ b/src/semantic-router/pkg/config/fragment_catalog_test.go @@ -31,6 +31,7 @@ func TestConfigFragmentCatalogCoversSupportedRoutingSurfaces(t *testing.T) { "elo": filepath.Join("selection", "elo.yaml"), "gmtrouter": filepath.Join("selection", "gmtrouter.yaml"), "hybrid": filepath.Join("selection", "hybrid.yaml"), + "session_aware": filepath.Join("selection", "session-aware.yaml"), "kmeans": filepath.Join("selection", "kmeans.yaml"), "knn": filepath.Join("selection", "knn.yaml"), "latency_aware": filepath.Join("selection", "latency-aware.yaml"), diff --git a/src/semantic-router/pkg/config/reference_config_public_surface_test.go b/src/semantic-router/pkg/config/reference_config_public_surface_test.go index 3d9af0c89a..b04ab5f3c9 100644 --- a/src/semantic-router/pkg/config/reference_config_public_surface_test.go +++ b/src/semantic-router/pkg/config/reference_config_public_surface_test.go @@ -51,17 +51,6 @@ func assertReferenceConfigRoutingCoverage(t testingT, root map[string]interface{ assertReferenceConfigSignalCoverage(t, mustMapAt(t, routing, "signals")) assertReferenceConfigProjectionCoverage(t, mustMapAt(t, routing, "projections")) assertReferenceConfigDecisionCoverage(t, mustSliceAt(t, routing, "decisions")) - assertReferenceConfigSessionStateCoverage(t, mustSliceAt(t, routing, "session_states")) -} - -func assertReferenceConfigSessionStateCoverage(t testingT, sessionStates []interface{}) { - assertSliceUnionCoversStructFields(t, sessionStates, reflect.TypeOf(SessionStateConfig{}), "routing.session_states") - assertSliceUnionCoversStructFields( - t, - collectNestedSliceItems(t, sessionStates, "fields", "routing.session_states"), - reflect.TypeOf(SessionStateFieldConfig{}), - "routing.session_states[].fields", - ) } func assertReferenceConfigSignalCoverage(t testingT, signals map[string]interface{}) { @@ -84,6 +73,7 @@ func assertReferenceConfigSignalCoverage(t testingT, signals map[string]interfac assertReferenceConfigStructureCoverage(t, mustSliceAt(t, signals, "structure")) assertReferenceConfigComplexityCoverage(t, mustSliceAt(t, signals, "complexity")) assertSliceUnionCoversStructFields(t, mustSliceAt(t, signals, "modality"), reflect.TypeOf(ModalityRule{}), "routing.signals.modality") + assertReferenceConfigSessionSignalCoverage(t, mustSliceAt(t, signals, "session")) assertReferenceConfigRoleBindingCoverage(t, mustSliceAt(t, signals, "role_bindings")) assertSliceUnionCoversStructFields(t, mustSliceAt(t, signals, "jailbreak"), reflect.TypeOf(JailbreakRule{}), "routing.signals.jailbreak") assertSliceUnionCoversStructFields(t, mustSliceAt(t, signals, "pii"), reflect.TypeOf(PIIRule{}), "routing.signals.pii") @@ -173,6 +163,16 @@ func assertReferenceConfigStructureCoverage(t testingT, structure []interface{}) ) } +func assertReferenceConfigSessionSignalCoverage(t testingT, session []interface{}) { + assertSliceUnionCoversStructFields(t, session, reflect.TypeOf(SessionRule{}), "routing.signals.session") + assertSliceUnionCoversStructFields( + t, + collectChildMapsFromSlice(t, session, "predicate", "routing.signals.session"), + reflect.TypeOf(NumericPredicate{}), + "routing.signals.session[].predicate", + ) +} + func assertReferenceConfigRoleBindingCoverage(t testingT, roleBindings []interface{}) { assertSliceUnionCoversStructFields(t, roleBindings, reflect.TypeOf(RoleBinding{}), "routing.signals.role_bindings") assertSliceUnionCoversStructFields( diff --git a/src/semantic-router/pkg/config/reference_config_routing_surface_test.go b/src/semantic-router/pkg/config/reference_config_routing_surface_test.go index 50dca3664d..33ac2034a3 100644 --- a/src/semantic-router/pkg/config/reference_config_routing_surface_test.go +++ b/src/semantic-router/pkg/config/reference_config_routing_surface_test.go @@ -13,6 +13,7 @@ var referenceSignalKeyByType = map[string]string{ SignalTypeKeyword: "keywords", SignalTypeLanguage: "language", SignalTypeModality: "modality", + SignalTypeSession: "session", SignalTypePII: "pii", SignalTypePreference: "preferences", SignalTypeReask: "reasks", @@ -48,6 +49,7 @@ func assertSupportedAlgorithmsInReferenceConfig(t testingT, decisions []interfac assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["router_dc"], "router_dc"), reflect.TypeOf(RouterDCSelectionConfig{}), "routing.decisions[].algorithm.router_dc") assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["automix"], "automix"), reflect.TypeOf(AutoMixSelectionConfig{}), "routing.decisions[].algorithm.automix") assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["hybrid"], "hybrid"), reflect.TypeOf(HybridSelectionConfig{}), "routing.decisions[].algorithm.hybrid") + assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["session_aware"], "session_aware"), reflect.TypeOf(SessionAwareSelectionConfig{}), "routing.decisions[].algorithm.session_aware") assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["rl_driven"], "rl_driven"), reflect.TypeOf(RLDrivenSelectionConfig{}), "routing.decisions[].algorithm.rl_driven") assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["gmtrouter"], "gmtrouter"), reflect.TypeOf(GMTRouterSelectionConfig{}), "routing.decisions[].algorithm.gmtrouter") assertMapCoversStructFields(t, mustMapAt(t, algorithmsByType["latency_aware"], "latency_aware"), reflect.TypeOf(LatencyAwareAlgorithmConfig{}), "routing.decisions[].algorithm.latency_aware") diff --git a/src/semantic-router/pkg/config/routing_surface_catalog.go b/src/semantic-router/pkg/config/routing_surface_catalog.go index 118cb73be6..2519af9776 100644 --- a/src/semantic-router/pkg/config/routing_surface_catalog.go +++ b/src/semantic-router/pkg/config/routing_surface_catalog.go @@ -27,6 +27,7 @@ var supportedSignalTypes = []string{ SignalTypeKeyword, SignalTypeLanguage, SignalTypeModality, + SignalTypeSession, SignalTypePII, SignalTypePreference, SignalTypeReask, @@ -62,6 +63,7 @@ var decisionAlgorithmCatalog = []AlgorithmCatalogEntry{ {Type: "elo", Tier: "supported"}, {Type: "gmtrouter", Tier: "experimental"}, {Type: "hybrid", Tier: "supported"}, + {Type: "session_aware", Tier: "supported"}, {Type: "kmeans", Tier: "experimental"}, {Type: "knn", Tier: "experimental"}, {Type: "latency_aware", Tier: "supported"}, diff --git a/src/semantic-router/pkg/config/selection_config.go b/src/semantic-router/pkg/config/selection_config.go index 85c3c2f771..65c3d4bfe3 100644 --- a/src/semantic-router/pkg/config/selection_config.go +++ b/src/semantic-router/pkg/config/selection_config.go @@ -9,11 +9,12 @@ type ModelSelectionConfig struct { Enabled bool `yaml:"enabled,omitempty"` // Family-specific configuration blocks. - Elo EloSelectionConfig `yaml:"elo,omitempty"` - RouterDC RouterDCSelectionConfig `yaml:"router_dc,omitempty"` - AutoMix AutoMixSelectionConfig `yaml:"automix,omitempty"` - Hybrid HybridSelectionConfig `yaml:"hybrid,omitempty"` - ML MLSelectionConfig `yaml:"ml,omitempty"` + Elo EloSelectionConfig `yaml:"elo,omitempty"` + RouterDC RouterDCSelectionConfig `yaml:"router_dc,omitempty"` + AutoMix AutoMixSelectionConfig `yaml:"automix,omitempty"` + Hybrid HybridSelectionConfig `yaml:"hybrid,omitempty"` + SessionAware SessionAwareSelectionConfig `yaml:"session_aware,omitempty"` + ML MLSelectionConfig `yaml:"ml,omitempty"` Momentum MomentumSelectionConfig `yaml:"momentum,omitempty"` // LookupTables configures persisted lookup tables for session-aware routing. diff --git a/src/semantic-router/pkg/config/session_selection_config.go b/src/semantic-router/pkg/config/session_selection_config.go new file mode 100644 index 0000000000..614f2af705 --- /dev/null +++ b/src/semantic-router/pkg/config/session_selection_config.go @@ -0,0 +1,26 @@ +package config + +// SessionAwareSelectionConfig configures stay-versus-switch routing for +// multi-turn sessions using runtime-derived session facts and lookup-table +// priors sourced from router replay. +type SessionAwareSelectionConfig struct { + // FallbackMethod is used when session context is unavailable or insufficient. + FallbackMethod string `yaml:"fallback_method,omitempty"` + + // MinTurnsBeforeSwitch suppresses switching early in a conversation. + MinTurnsBeforeSwitch int `yaml:"min_turns_before_switch,omitempty"` + + // StayBias adds a baseline preference for keeping the current session model. + StayBias float64 `yaml:"stay_bias,omitempty"` + + // QualityGapMultiplier scales lookup-table quality-gap estimates when + // evaluating a potential switch to another model. + QualityGapMultiplier float64 `yaml:"quality_gap_multiplier,omitempty"` + + // HandoffPenaltyWeight scales replay-derived switch penalties between models. + HandoffPenaltyWeight float64 `yaml:"handoff_penalty_weight,omitempty"` + + // RemainingTurnWeight increases the value of staying on the current model + // when the conversation is expected to continue for more turns. + RemainingTurnWeight float64 `yaml:"remaining_turn_weight,omitempty"` +} diff --git a/src/semantic-router/pkg/config/session_signal_config.go b/src/semantic-router/pkg/config/session_signal_config.go new file mode 100644 index 0000000000..054d383c12 --- /dev/null +++ b/src/semantic-router/pkg/config/session_signal_config.go @@ -0,0 +1,44 @@ +package config + +import "strings" + +const ( + SessionFactSessionPresent = "session_present" + SessionFactHasPreviousModel = "has_previous_model" + SessionFactTurnIndex = "turn_index" + SessionFactCacheWarmth = "cache_warmth" + SessionFactRemainingTurns = "remaining_turns" + SessionFactHandoffPenalty = "handoff_penalty" + SessionFactQualityGap = "quality_gap" +) + +// SessionRule declares a runtime-derived session signal that can be referenced +// directly from routing decisions and projection inputs. +type SessionRule struct { + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + Fact string `yaml:"fact"` + Predicate *NumericPredicate `yaml:"predicate,omitempty"` + IntentOrDomain string `yaml:"intent_or_domain,omitempty"` + PreviousModel string `yaml:"previous_model,omitempty"` + CandidateModel string `yaml:"candidate_model,omitempty"` +} + +func NormalizeSessionFact(fact string) string { + return strings.ToLower(strings.TrimSpace(fact)) +} + +func IsSupportedSessionFact(fact string) bool { + switch NormalizeSessionFact(fact) { + case SessionFactSessionPresent, + SessionFactHasPreviousModel, + SessionFactTurnIndex, + SessionFactCacheWarmth, + SessionFactRemainingTurns, + SessionFactHandoffPenalty, + SessionFactQualityGap: + return true + default: + return false + } +} diff --git a/src/semantic-router/pkg/config/session_state_config.go b/src/semantic-router/pkg/config/session_state_config.go deleted file mode 100644 index 1da6cc912b..0000000000 --- a/src/semantic-router/pkg/config/session_state_config.go +++ /dev/null @@ -1,14 +0,0 @@ -package config - -// SessionStateFieldConfig is one typed field inside a SESSION_STATE declaration. -type SessionStateFieldConfig struct { - Name string `yaml:"name"` - TypeName string `yaml:"type"` -} - -// SessionStateConfig represents a SESSION_STATE declaration, naming the -// cross-turn fields that session-aware routing policies can reference. -type SessionStateConfig struct { - Name string `yaml:"name"` - Fields []SessionStateFieldConfig `yaml:"fields,omitempty"` -} diff --git a/src/semantic-router/pkg/config/signal_config.go b/src/semantic-router/pkg/config/signal_config.go index 76b8bc92d9..bcaabf99c5 100644 --- a/src/semantic-router/pkg/config/signal_config.go +++ b/src/semantic-router/pkg/config/signal_config.go @@ -19,6 +19,7 @@ type Signals struct { StructureRules []StructureRule `yaml:"structure_rules,omitempty"` ComplexityRules []ComplexityRule `yaml:"complexity_rules,omitempty"` ModalityRules []ModalityRule `yaml:"modality_rules,omitempty"` + SessionRules []SessionRule `yaml:"session,omitempty"` RoleBindings []RoleBinding `yaml:"role_bindings,omitempty"` JailbreakRules []JailbreakRule `yaml:"jailbreak,omitempty"` PIIRules []PIIRule `yaml:"pii,omitempty"` diff --git a/src/semantic-router/pkg/config/validator.go b/src/semantic-router/pkg/config/validator.go index e8685e74ce..f207a89056 100644 --- a/src/semantic-router/pkg/config/validator.go +++ b/src/semantic-router/pkg/config/validator.go @@ -106,6 +106,9 @@ func validateConfigStructure(cfg *RouterConfig) error { if err := validateReaskContracts(cfg); err != nil { return err } + if err := validateSessionContracts(cfg); err != nil { + return err + } if err := validateProjectionContracts(cfg); err != nil { return err } diff --git a/src/semantic-router/pkg/config/validator_decision.go b/src/semantic-router/pkg/config/validator_decision.go index 682751a133..9d135296f3 100644 --- a/src/semantic-router/pkg/config/validator_decision.go +++ b/src/semantic-router/pkg/config/validator_decision.go @@ -139,6 +139,7 @@ func validateDecisionAlgorithmConfig(decisionName string, algorithm *AlgorithmCo addBlock("router_dc", algorithm.RouterDC != nil) addBlock("automix", algorithm.AutoMix != nil) addBlock("hybrid", algorithm.Hybrid != nil) + addBlock("session_aware", algorithm.SessionAware != nil) addBlock("rl_driven", algorithm.RLDriven != nil) addBlock("gmtrouter", algorithm.GMTRouter != nil) addBlock("latency_aware", algorithm.LatencyAware != nil) @@ -160,6 +161,7 @@ func validateDecisionAlgorithmConfig(decisionName string, algorithm *AlgorithmCo "router_dc": "router_dc", "automix": "automix", "hybrid": "hybrid", + "session_aware": "session_aware", "rl_driven": "rl_driven", "gmtrouter": "gmtrouter", "latency_aware": "latency_aware", @@ -188,6 +190,15 @@ func validateDecisionAlgorithmConfig(decisionName string, algorithm *AlgorithmCo ) } + if normalizedType == "session_aware" { + if algorithm.SessionAware == nil { + return fmt.Errorf("decision '%s': algorithm.type=session_aware requires algorithm.session_aware configuration", decisionName) + } + if err := validateSessionAwareAlgorithmConfig(algorithm.SessionAware); err != nil { + return fmt.Errorf("decision '%s', algorithm.session_aware: %w", decisionName, err) + } + } + if normalizedType == "latency_aware" { if algorithm.LatencyAware == nil { return fmt.Errorf("decision '%s': algorithm.type=latency_aware requires algorithm.latency_aware configuration", decisionName) diff --git a/src/semantic-router/pkg/config/validator_projection.go b/src/semantic-router/pkg/config/validator_projection.go index 3d67a2939d..c432763335 100644 --- a/src/semantic-router/pkg/config/validator_projection.go +++ b/src/semantic-router/pkg/config/validator_projection.go @@ -238,6 +238,7 @@ func isProjectionInputTypeSupported(signalType string) bool { SignalTypeStructure, SignalTypeComplexity, SignalTypeModality, + SignalTypeSession, SignalTypeAuthz, SignalTypeJailbreak, SignalTypePII, @@ -263,6 +264,7 @@ func projectionDeclaredSignals(cfg *RouterConfig) map[string]map[string]struct{} SignalTypeStructure: collectStructureRuleNames(cfg.StructureRules), SignalTypeComplexity: collectComplexityRuleNames(cfg.ComplexityRules), SignalTypeModality: collectModalityRuleNames(cfg.ModalityRules), + SignalTypeSession: collectSessionRuleNames(cfg.SessionRules), SignalTypeAuthz: collectRoleBindingNames(cfg.GetRoleBindings()), SignalTypeJailbreak: collectJailbreakRuleNames(cfg.JailbreakRules), SignalTypePII: collectPIIRuleNames(cfg.PIIRules), @@ -480,6 +482,14 @@ func collectModalityRuleNames(rules []ModalityRule) map[string]struct{} { return names } +func collectSessionRuleNames(rules []SessionRule) map[string]struct{} { + names := make(map[string]struct{}, len(rules)) + for _, rule := range rules { + names[rule.Name] = struct{}{} + } + return names +} + func collectRoleBindingNames(rules []RoleBinding) map[string]struct{} { names := make(map[string]struct{}, len(rules)*2) for _, rule := range rules { diff --git a/src/semantic-router/pkg/config/validator_projection_test.go b/src/semantic-router/pkg/config/validator_projection_test.go index 540a814288..e57dcc8e7e 100644 --- a/src/semantic-router/pkg/config/validator_projection_test.go +++ b/src/semantic-router/pkg/config/validator_projection_test.go @@ -94,3 +94,45 @@ routing: t.Fatalf("unexpected error: %v", err) } } + +func TestParseRoutingYAMLBytesAcceptsSessionProjectionScoreInput(t *testing.T) { + yaml := []byte(` +routing: + signals: + session: + - name: session_present + fact: session_present + predicate: + gte: 1 + projections: + scores: + - name: continuity_score + method: weighted_sum + inputs: + - type: session + name: session_present + weight: 0.6 + mappings: + - name: continuity_band + source: continuity_score + method: threshold_bands + outputs: + - name: continue_session + gte: 0.5 + decisions: + - name: session_route + rules: + operator: AND + conditions: + - type: projection + name: continue_session + modelRefs: + - model: qwen3-8b +`) + + _, err := ParseRoutingYAMLBytes(yaml) + if err != nil { + fatalf := t.Fatalf + fatalf("expected session projection input to validate, got: %v", err) + } +} diff --git a/src/semantic-router/pkg/config/validator_session.go b/src/semantic-router/pkg/config/validator_session.go new file mode 100644 index 0000000000..ee3c408c55 --- /dev/null +++ b/src/semantic-router/pkg/config/validator_session.go @@ -0,0 +1,83 @@ +package config + +import ( + "fmt" + "strings" +) + +func validateSessionContracts(cfg *RouterConfig) error { + for i, rule := range cfg.SessionRules { + if err := validateSessionRule(rule); err != nil { + return fmt.Errorf("routing.signals.session[%d]: %w", i, err) + } + } + return nil +} + +func validateSessionRule(rule SessionRule) error { + if strings.TrimSpace(rule.Name) == "" { + return fmt.Errorf("name cannot be empty") + } + if !IsSupportedSessionFact(rule.Fact) { + return fmt.Errorf("fact %q is unsupported", rule.Fact) + } + if rule.Predicate == nil { + return fmt.Errorf("predicate is required") + } + if err := validateNumericPredicate(rule.Predicate); err != nil { + return fmt.Errorf("predicate: %w", err) + } + + switch NormalizeSessionFact(rule.Fact) { + case SessionFactQualityGap, SessionFactHandoffPenalty: + if strings.TrimSpace(rule.CandidateModel) == "" { + return fmt.Errorf("candidate_model is required for fact %q", rule.Fact) + } + } + return nil +} + +func validateSessionAwareAlgorithmConfig(cfg *SessionAwareSelectionConfig) error { + if cfg == nil { + return fmt.Errorf("configuration cannot be nil") + } + if cfg.MinTurnsBeforeSwitch < 0 { + return fmt.Errorf("min_turns_before_switch must be >= 0") + } + for _, field := range []struct { + name string + value float64 + }{ + {name: "stay_bias", value: cfg.StayBias}, + {name: "quality_gap_multiplier", value: cfg.QualityGapMultiplier}, + {name: "handoff_penalty_weight", value: cfg.HandoffPenaltyWeight}, + {name: "remaining_turn_weight", value: cfg.RemainingTurnWeight}, + } { + if field.value < 0 { + return fmt.Errorf("%s must be >= 0", field.name) + } + } + if cfg.FallbackMethod != "" && !IsSupportedDecisionAlgorithmType(cfg.FallbackMethod) { + return fmt.Errorf("fallback_method %q is unsupported", cfg.FallbackMethod) + } + if strings.EqualFold(cfg.FallbackMethod, "session_aware") { + return fmt.Errorf("fallback_method cannot be session_aware") + } + return nil +} + +func validateNumericPredicate(predicate *NumericPredicate) error { + if predicate == nil { + return fmt.Errorf("cannot be nil") + } + if predicate.GT == nil && predicate.GTE == nil && predicate.LT == nil && predicate.LTE == nil { + return fmt.Errorf("at least one of gt, gte, lt, lte is required") + } + if predicate.GT != nil && predicate.GTE != nil { + return fmt.Errorf("cannot set both gt and gte") + } + if predicate.LT != nil && predicate.LTE != nil { + return fmt.Errorf("cannot set both lt and lte") + } + return nil +} diff --git a/src/semantic-router/pkg/decision/engine.go b/src/semantic-router/pkg/decision/engine.go index 1d0d054be5..c910db4b11 100644 --- a/src/semantic-router/pkg/decision/engine.go +++ b/src/semantic-router/pkg/decision/engine.go @@ -71,6 +71,7 @@ type SignalMatches struct { StructureRules []string // Structure rule names matched (e.g. "many_questions") ComplexityRules []string // Complexity rules with difficulty level (e.g. "code_complexity:hard") ModalityRules []string // Modality classification: "AR", "DIFFUSION", or "BOTH" + SessionRules []string // Runtime-derived session signals (e.g. warm cache / continuation risk) AuthzRules []string // Authz rule names matched for user-level routing (e.g. "premium_tier") JailbreakRules []string // Jailbreak rule names matched (confidence >= threshold) PIIRules []string // PII rule names matched (denied PII types detected) diff --git a/src/semantic-router/pkg/dsl/ast.go b/src/semantic-router/pkg/dsl/ast.go index 8d2f2ea91a..17ef992f52 100644 --- a/src/semantic-router/pkg/dsl/ast.go +++ b/src/semantic-router/pkg/dsl/ast.go @@ -39,7 +39,6 @@ type rawTopLevel struct { Model *rawModelDecl `parser:"| @@"` Plugin *rawPluginDecl `parser:"| @@"` TestBlock *rawTestBlockDecl `parser:"| @@"` - SessionState *rawSessionStateDecl `parser:"| @@"` } // rawTestBlockDecl: TEST { entries... } @@ -56,13 +55,6 @@ type rawTestEntryDecl struct { RouteName string `parser:"Arrow @(Ident | String)"` } -// rawSessionStateDecl: SESSION_STATE { fields... } -type rawSessionStateDecl struct { - Pos lexer.Position - Name string `parser:"'SESSION_STATE' @(Ident | String)"` - Fields []*FieldEntry `parser:"'{' @@* '}'"` -} - // rawSignalDecl: SIGNAL { fields... } type rawSignalDecl struct { Pos lexer.Position @@ -255,20 +247,6 @@ type Program struct { Models []*ModelDecl Plugins []*PluginDecl TestBlocks []*TestBlockDecl - SessionStates []*SessionStateDecl -} - -// SessionStateField is one typed field in a SessionStateDecl. -type SessionStateField struct { - Name string - TypeName string // "int", "string", or "float" -} - -// SessionStateDecl represents a SESSION_STATE top-level declaration. -type SessionStateDecl struct { - Name string - Fields []SessionStateField - Pos Position } // ProjectionPartitionDecl declares a mutually exclusive partition of signals. diff --git a/src/semantic-router/pkg/dsl/compiler.go b/src/semantic-router/pkg/dsl/compiler.go index ea64679c88..25cd9d5fe6 100644 --- a/src/semantic-router/pkg/dsl/compiler.go +++ b/src/semantic-router/pkg/dsl/compiler.go @@ -59,26 +59,10 @@ func (c *Compiler) compile() { // 5. Compile top-level model catalog c.compileModels() - // 6. Compile session state declarations - c.compileSessionStates() - - // 7. Compile routes (decisions) + // 6. Compile routes (decisions) c.compileRoutes() } -func (c *Compiler) compileSessionStates() { - for _, decl := range c.prog.SessionStates { - ss := config.SessionStateConfig{Name: decl.Name} - for _, f := range decl.Fields { - ss.Fields = append(ss.Fields, config.SessionStateFieldConfig{ - Name: f.Name, - TypeName: f.TypeName, - }) - } - c.config.SessionStates = append(c.config.SessionStates, ss) - } -} - func (c *Compiler) compileProjectionPartitions() { for _, partitionDecl := range c.prog.ProjectionPartitions { c.validateSoftmaxDomainProjectionPartition(partitionDecl) @@ -172,6 +156,8 @@ func (c *Compiler) compileSignals() { c.compileComplexitySignal(s) case "modality": c.compileModalitySignal(s) + case "session": + c.compileSessionSignal(s) case "authz": c.compileAuthzSignal(s) case "jailbreak": @@ -386,6 +372,24 @@ func (c *Compiler) compileModalitySignal(s *SignalDecl) { c.config.ModalityRules = append(c.config.ModalityRules, rule) } +func (c *Compiler) compileSessionSignal(s *SignalDecl) { + payload := fieldsToMap(s.Fields) + payload["name"] = s.Name + + raw, err := yaml.Marshal(payload) + if err != nil { + c.addError(s.Pos, "failed to encode session signal %q: %v", s.Name, err) + return + } + + var rule config.SessionRule + if err := yaml.Unmarshal(raw, &rule); err != nil { + c.addError(s.Pos, "failed to decode session signal %q: %v", s.Name, err) + return + } + c.config.SessionRules = append(c.config.SessionRules, rule) +} + func (c *Compiler) compileJailbreakSignal(s *SignalDecl) { rule := config.JailbreakRule{Name: s.Name} if v, ok := getStringField(s.Fields, "method"); ok { @@ -636,6 +640,8 @@ func (c *Compiler) compileAlgorithm(spec *AlgoSpec) *config.AlgorithmConfig { algo.AutoMix = c.compileAutoMixAlgo(spec.Fields) case "hybrid": algo.Hybrid = c.compileHybridAlgo(spec.Fields) + case "session_aware": + algo.SessionAware = c.compileSessionAwareAlgo(spec.Fields) case "rl_driven": algo.RLDriven = c.compileRLDrivenAlgo(spec.Fields) case "gmtrouter": @@ -819,6 +825,29 @@ func (c *Compiler) compileHybridAlgo(fields map[string]Value) *config.HybridSele return cfg } +func (c *Compiler) compileSessionAwareAlgo(fields map[string]Value) *config.SessionAwareSelectionConfig { + cfg := &config.SessionAwareSelectionConfig{} + if v, ok := getStringField(fields, "fallback_method"); ok { + cfg.FallbackMethod = v + } + if v, ok := getIntField(fields, "min_turns_before_switch"); ok { + cfg.MinTurnsBeforeSwitch = v + } + if v, ok := getFloat64Field(fields, "stay_bias"); ok { + cfg.StayBias = v + } + if v, ok := getFloat64Field(fields, "quality_gap_multiplier"); ok { + cfg.QualityGapMultiplier = v + } + if v, ok := getFloat64Field(fields, "handoff_penalty_weight"); ok { + cfg.HandoffPenaltyWeight = v + } + if v, ok := getFloat64Field(fields, "remaining_turn_weight"); ok { + cfg.RemainingTurnWeight = v + } + return cfg +} + func (c *Compiler) compileRLDrivenAlgo(fields map[string]Value) *config.RLDrivenSelectionConfig { cfg := &config.RLDrivenSelectionConfig{} if v, ok := getFloat64Field(fields, "exploration_rate"); ok { diff --git a/src/semantic-router/pkg/dsl/decompiler.go b/src/semantic-router/pkg/dsl/decompiler.go index 2fe568dc4f..336f95bd21 100644 --- a/src/semantic-router/pkg/dsl/decompiler.go +++ b/src/semantic-router/pkg/dsl/decompiler.go @@ -32,18 +32,6 @@ type pluginTemplate struct { usageCount int } -// ---------- Session State Decompilation ---------- - -func (d *decompiler) decompileSessionStates() { - for _, ss := range d.cfg.SessionStates { - d.write("SESSION_STATE %s {\n", quoteName(ss.Name)) - for _, f := range ss.Fields { - d.write(" %s: %s\n", f.Name, f.TypeName) - } - d.write("}\n\n") - } -} - // ---------- Signal Decompilation ---------- func (d *decompiler) decompileSignals() { @@ -207,6 +195,44 @@ func (d *decompiler) decompileSignals() { d.write("}\n\n") } + for _, session := range d.cfg.SessionRules { + d.write("SIGNAL session %s {\n", quoteName(session.Name)) + if session.Description != "" { + d.write(" description: %q\n", session.Description) + } + if session.Fact != "" { + d.write(" fact: %q\n", session.Fact) + } + if session.IntentOrDomain != "" { + d.write(" intent_or_domain: %q\n", session.IntentOrDomain) + } + if session.PreviousModel != "" { + d.write(" previous_model: %q\n", session.PreviousModel) + } + if session.CandidateModel != "" { + d.write(" candidate_model: %q\n", session.CandidateModel) + } + if session.Predicate != nil { + parts := make([]string, 0, 4) + if session.Predicate.GT != nil { + parts = append(parts, fmt.Sprintf("gt: %g", *session.Predicate.GT)) + } + if session.Predicate.GTE != nil { + parts = append(parts, fmt.Sprintf("gte: %g", *session.Predicate.GTE)) + } + if session.Predicate.LT != nil { + parts = append(parts, fmt.Sprintf("lt: %g", *session.Predicate.LT)) + } + if session.Predicate.LTE != nil { + parts = append(parts, fmt.Sprintf("lte: %g", *session.Predicate.LTE)) + } + if len(parts) > 0 { + d.write(" predicate: { %s }\n", strings.Join(parts, ", ")) + } + } + d.write("}\n\n") + } + for _, rb := range d.cfg.RoleBindings { d.write("SIGNAL authz %s {\n", quoteName(rb.Name)) if rb.Role != "" { @@ -989,6 +1015,27 @@ func (d *decompiler) decompileAlgorithmFields(algo *config.AlgorithmConfig) stri fmt.Fprintf(&sb, " max_escalations: %d\n", a.MaxEscalations) } } + case "session_aware": + if s := algo.SessionAware; s != nil { + if s.FallbackMethod != "" { + fmt.Fprintf(&sb, " fallback_method: %q\n", s.FallbackMethod) + } + if s.MinTurnsBeforeSwitch != 0 { + fmt.Fprintf(&sb, " min_turns_before_switch: %d\n", s.MinTurnsBeforeSwitch) + } + if s.StayBias != 0 { + fmt.Fprintf(&sb, " stay_bias: %v\n", s.StayBias) + } + if s.QualityGapMultiplier != 0 { + fmt.Fprintf(&sb, " quality_gap_multiplier: %v\n", s.QualityGapMultiplier) + } + if s.HandoffPenaltyWeight != 0 { + fmt.Fprintf(&sb, " handoff_penalty_weight: %v\n", s.HandoffPenaltyWeight) + } + if s.RemainingTurnWeight != 0 { + fmt.Fprintf(&sb, " remaining_turn_weight: %v\n", s.RemainingTurnWeight) + } + } case "latency_aware": if l := algo.LatencyAware; l != nil { if l.TPOTPercentile != 0 { @@ -1151,6 +1198,44 @@ func (d *decompiler) modalityToSignal(mod *config.ModalityRule) *SignalDecl { return &SignalDecl{SignalType: "modality", Name: mod.Name, Fields: fields} } +func (d *decompiler) sessionToSignal(rule *config.SessionRule) *SignalDecl { + fields := make(map[string]Value) + if rule.Description != "" { + fields["description"] = StringValue{V: rule.Description} + } + if rule.Fact != "" { + fields["fact"] = StringValue{V: rule.Fact} + } + if rule.IntentOrDomain != "" { + fields["intent_or_domain"] = StringValue{V: rule.IntentOrDomain} + } + if rule.PreviousModel != "" { + fields["previous_model"] = StringValue{V: rule.PreviousModel} + } + if rule.CandidateModel != "" { + fields["candidate_model"] = StringValue{V: rule.CandidateModel} + } + if rule.Predicate != nil { + predicateFields := make(map[string]Value) + if rule.Predicate.GT != nil { + predicateFields["gt"] = FloatValue{V: *rule.Predicate.GT} + } + if rule.Predicate.GTE != nil { + predicateFields["gte"] = FloatValue{V: *rule.Predicate.GTE} + } + if rule.Predicate.LT != nil { + predicateFields["lt"] = FloatValue{V: *rule.Predicate.LT} + } + if rule.Predicate.LTE != nil { + predicateFields["lte"] = FloatValue{V: *rule.Predicate.LTE} + } + if len(predicateFields) > 0 { + fields["predicate"] = ObjectValue{Fields: predicateFields} + } + } + return &SignalDecl{SignalType: "session", Name: rule.Name, Fields: fields} +} + func (d *decompiler) roleBindingToSignal(rb *config.RoleBinding) *SignalDecl { fields := make(map[string]Value) if rb.Role != "" { @@ -1562,6 +1647,27 @@ func (d *decompiler) algorithmToFields(algo *config.AlgorithmConfig) map[string] fields["ttft_percentile"] = IntValue{V: l.TTFTPercentile} } } + case "session_aware": + if s := algo.SessionAware; s != nil { + if s.FallbackMethod != "" { + fields["fallback_method"] = StringValue{V: s.FallbackMethod} + } + if s.MinTurnsBeforeSwitch != 0 { + fields["min_turns_before_switch"] = IntValue{V: s.MinTurnsBeforeSwitch} + } + if s.StayBias != 0 { + fields["stay_bias"] = FloatValue{V: s.StayBias} + } + if s.QualityGapMultiplier != 0 { + fields["quality_gap_multiplier"] = FloatValue{V: s.QualityGapMultiplier} + } + if s.HandoffPenaltyWeight != 0 { + fields["handoff_penalty_weight"] = FloatValue{V: s.HandoffPenaltyWeight} + } + if s.RemainingTurnWeight != 0 { + fields["remaining_turn_weight"] = FloatValue{V: s.RemainingTurnWeight} + } + } case "confidence": if c := algo.Confidence; c != nil { if c.ConfidenceMethod != "" { diff --git a/src/semantic-router/pkg/dsl/dsl_test.go b/src/semantic-router/pkg/dsl/dsl_test.go index db8194b9e4..9fe5fd5466 100644 --- a/src/semantic-router/pkg/dsl/dsl_test.go +++ b/src/semantic-router/pkg/dsl/dsl_test.go @@ -1483,6 +1483,25 @@ func TestCompileAllAlgorithmTypes(t *testing.T) { } }, }, + { + name: "session_aware", + algoType: "session_aware", + body: `fallback_method: "static" min_turns_before_switch: 2 stay_bias: 0.4 quality_gap_multiplier: 1.2 handoff_penalty_weight: 0.8 remaining_turn_weight: 0.6`, + verify: func(t *testing.T, algo *config.AlgorithmConfig) { + if algo.SessionAware == nil { + t.Fatal("expected session_aware config") + } + if algo.SessionAware.MinTurnsBeforeSwitch != 2 { + t.Errorf("min_turns_before_switch = %d, want 2", algo.SessionAware.MinTurnsBeforeSwitch) + } + if algo.SessionAware.FallbackMethod != "static" { + t.Errorf("fallback_method = %q, want static", algo.SessionAware.FallbackMethod) + } + if algo.SessionAware.StayBias != 0.4 { + t.Errorf("stay_bias = %v, want 0.4", algo.SessionAware.StayBias) + } + }, + }, { name: "rl_driven", algoType: "rl_driven", @@ -3265,6 +3284,48 @@ ROUTE test { } } +func TestDecompileSessionAwareAlgorithmFields(t *testing.T) { + input := ` +SIGNAL domain test { description: "test" } +ROUTE test { + PRIORITY 1 + WHEN domain("test") + MODEL "m1:7b", "m2:3b" + ALGORITHM session_aware { + fallback_method: "static" + min_turns_before_switch: 2 + stay_bias: 0.4 + quality_gap_multiplier: 1.2 + handoff_penalty_weight: 0.8 + remaining_turn_weight: 0.6 + } +} +` + cfg, errs := Compile(input) + if len(errs) > 0 { + t.Fatalf("compile errors: %v", errs) + } + + dslText, err := Decompile(cfg) + if err != nil { + t.Fatalf("decompile error: %v", err) + } + + for _, expected := range []string{ + "ALGORITHM session_aware", + "fallback_method", + "min_turns_before_switch", + "stay_bias", + "quality_gap_multiplier", + "handoff_penalty_weight", + "remaining_turn_weight", + } { + if !strings.Contains(dslText, expected) { + t.Errorf("missing %q in decompiled output", expected) + } + } +} + func TestDecompileToAST(t *testing.T) { cfg, errs := Compile(fullDSLExample) if len(errs) > 0 { @@ -5562,153 +5623,3 @@ func assertConflictFreeRoundTrip(t *testing.T, cfg *config.RouterConfig) { t.Errorf("re-parsed projection partitions = %d, want 1", len(prog2.ProjectionPartitions)) } } - -// ---------- SESSION_STATE Tests ---------- - -const sessionStateDSL = `SESSION_STATE session_routing { - turn_number: int - current_model: string - cumulative_cost_usd: float - retry_count_ema: float - quality_score_ema: float - kv_cache_warm: float -}` - -func TestCompileSessionState(t *testing.T) { - cfg, errs := Compile(sessionStateDSL) - if len(errs) > 0 { - t.Fatalf("compile errors: %v", errs) - } - if len(cfg.SessionStates) != 1 { - t.Fatalf("expected 1 SessionState in config, got %d", len(cfg.SessionStates)) - } - ss := cfg.SessionStates[0] - if ss.Name != "session_routing" { - t.Errorf("name: expected %q, got %q", "session_routing", ss.Name) - } - wantFields := []struct{ name, typeName string }{ - {"turn_number", "int"}, - {"current_model", "string"}, - {"cumulative_cost_usd", "float"}, - {"retry_count_ema", "float"}, - {"quality_score_ema", "float"}, - {"kv_cache_warm", "float"}, - } - if len(ss.Fields) != len(wantFields) { - t.Fatalf("expected %d fields, got %d", len(wantFields), len(ss.Fields)) - } - for i, w := range wantFields { - if ss.Fields[i].Name != w.name || ss.Fields[i].TypeName != w.typeName { - t.Errorf("field[%d]: expected {%s: %s}, got {%s: %s}", - i, w.name, w.typeName, ss.Fields[i].Name, ss.Fields[i].TypeName) - } - } -} - -func TestSessionStateRoundTrip(t *testing.T) { - cfg, errs := Compile(sessionStateDSL) - if len(errs) > 0 { - t.Fatalf("compile errors: %v", errs) - } - dslText, err := DecompileRouting(cfg) - if err != nil { - t.Fatalf("decompile error: %v", err) - } - if !strings.Contains(dslText, "SESSION_STATE session_routing") { - t.Errorf("round-trip lost SESSION_STATE declaration\nDSL:\n%s", dslText) - } - // Types must survive as bare identifiers, not quoted strings. - for _, typeName := range []string{"int", "string", "float"} { - if !strings.Contains(dslText, ": "+typeName) { - t.Errorf("round-trip DSL missing bare type %q\nDSL:\n%s", typeName, dslText) - } - } - prog2, errs2 := Parse(dslText) - if len(errs2) > 0 { - t.Fatalf("re-parse errors after round-trip: %v\nDSL:\n%s", errs2, dslText) - } - if len(prog2.SessionStates) != 1 { - t.Fatalf("re-parsed session states = %d, want 1\nDSL:\n%s", len(prog2.SessionStates), dslText) - } - if prog2.SessionStates[0].Name != "session_routing" { - t.Errorf("round-trip name: expected %q, got %q", "session_routing", prog2.SessionStates[0].Name) - } -} - -func TestSessionStateRoundTripAST(t *testing.T) { - cfg, errs := Compile(sessionStateDSL) - if len(errs) > 0 { - t.Fatalf("compile errors: %v", errs) - } - prog := DecompileToAST(cfg) - if len(prog.SessionStates) != 1 { - t.Fatalf("expected 1 SessionState in AST round-trip, got %d", len(prog.SessionStates)) - } - ss := prog.SessionStates[0] - if ss.Name != "session_routing" { - t.Errorf("AST round-trip name: expected %q, got %q", "session_routing", ss.Name) - } - if len(ss.Fields) != 6 { - t.Errorf("AST round-trip field count: expected 6, got %d", len(ss.Fields)) - } -} - -func TestValidateSessionStateDuplicateName(t *testing.T) { - input := ` -SESSION_STATE foo { x: int } -SESSION_STATE foo { y: string } -` - diags, _ := Validate(input) - found := false - for _, d := range diags { - if d.Level == DiagConstraint && strings.Contains(d.Message, "duplicate") { - found = true - break - } - } - if !found { - t.Errorf("expected DiagConstraint for duplicate SESSION_STATE name, got: %v", diags) - } -} - -func TestValidateSessionStateInvalidType(t *testing.T) { - input := `SESSION_STATE foo { x: bool }` - diags, _ := Validate(input) - found := false - for _, d := range diags { - if d.Level == DiagConstraint && - strings.Contains(d.Message, "invalid type") && - strings.Contains(d.Message, "bool") { - found = true - break - } - } - if !found { - t.Errorf("expected DiagConstraint for invalid type %q, got: %v", "bool", diags) - } -} - -func TestValidateSessionStateDuplicateField(t *testing.T) { - input := `SESSION_STATE foo { x: int, x: string }` - diags, _ := Validate(input) - found := false - for _, d := range diags { - if d.Level == DiagConstraint && strings.Contains(d.Message, "duplicate field") { - found = true - break - } - } - if !found { - t.Errorf("expected DiagConstraint for duplicate field name, got: %v", diags) - } -} - -func TestValidateSessionStateValidTypes(t *testing.T) { - input := `SESSION_STATE ok { a: int, b: string, c: float }` - diags, _ := Validate(input) - for _, d := range diags { - if d.Level == DiagConstraint { - t.Errorf("unexpected constraint diagnostic for valid types: %v", d) - } - } -} diff --git a/src/semantic-router/pkg/dsl/maintained_asset_roundtrip_test.go b/src/semantic-router/pkg/dsl/maintained_asset_roundtrip_test.go index d916f85a2f..ca2de93c43 100644 --- a/src/semantic-router/pkg/dsl/maintained_asset_roundtrip_test.go +++ b/src/semantic-router/pkg/dsl/maintained_asset_roundtrip_test.go @@ -374,29 +374,6 @@ func ruleTreeContainsNegatedSignal(node *config.RuleCombination, signalType stri return ruleTreeContainsSignal(node, false, signalType, name) } -func TestMaintainedSessionStateRecipeAssetsStayInSync(t *testing.T) { - dslPath := filepath.Join("..", "..", "..", "..", "deploy", "recipes", "session-state.dsl") - yamlPath := filepath.Join("..", "..", "..", "..", "deploy", "recipes", "session-state.yaml") - - dslData, err := os.ReadFile(dslPath) - if err != nil { - t.Fatalf("failed to read session-state.dsl: %v", err) - } - want, errs := Compile(string(dslData)) - if len(errs) > 0 { - t.Fatalf("Compile errors for session-state.dsl: %v", errs) - } - - got, err := config.Parse(yamlPath) - if err != nil { - t.Fatalf("failed to parse session-state.yaml: %v", err) - } - - if !reflect.DeepEqual(got.SessionStates, want.SessionStates) { - t.Fatalf("session-state DSL/YAML assets diverged\nwant: %+v\ngot: %+v", want.SessionStates, got.SessionStates) - } -} - func ruleTreeContainsSignal(node *config.RuleCombination, negated bool, signalType string, name string) bool { if node == nil { return false diff --git a/src/semantic-router/pkg/dsl/parser.go b/src/semantic-router/pkg/dsl/parser.go index d3e164f4a1..49ec9ec853 100644 --- a/src/semantic-router/pkg/dsl/parser.go +++ b/src/semantic-router/pkg/dsl/parser.go @@ -75,7 +75,6 @@ func Parse(input string) (*Program, []error) { prog.Models = append(prog.Models, resolved.Models...) prog.Plugins = append(prog.Plugins, resolved.Plugins...) prog.TestBlocks = append(prog.TestBlocks, resolved.TestBlocks...) - prog.SessionStates = append(prog.SessionStates, resolved.SessionStates...) allErrors = append(allErrors, lowerErrs...) } @@ -92,7 +91,7 @@ func splitTopLevelBlocks(input string) []string { var blocks []string depth := 0 start := 0 - keywords := []string{"SESSION_STATE", "DECISION_TREE", "PROJECTION", "SIGNAL", "ROUTE", "MODEL", "PLUGIN", "TEST"} + keywords := []string{"DECISION_TREE", "PROJECTION", "SIGNAL", "ROUTE", "MODEL", "PLUGIN", "TEST"} for i := 0; i < len(input); i++ { ch := input[i] @@ -182,8 +181,6 @@ func rawToProgram(raw *rawProgram) (*Program, []error) { prog.Plugins = append(prog.Plugins, rawToPlugin(entry.Plugin)) case entry.TestBlock != nil: prog.TestBlocks = append(prog.TestBlocks, rawToTestBlock(entry.TestBlock)) - case entry.SessionState != nil: - prog.SessionStates = append(prog.SessionStates, rawToSessionState(entry.SessionState)) } } if hasDirectRoutes && treeCount > 0 { @@ -345,30 +342,6 @@ func rawToTestBlock(r *rawTestBlockDecl) *TestBlockDecl { return tb } -func rawToSessionState(r *rawSessionStateDecl) *SessionStateDecl { - decl := &SessionStateDecl{ - Name: unquoteIdent(r.Name), - Pos: posFromLexer(r.Pos), - } - for _, entry := range r.Fields { - if entry == nil || entry.Value == nil { - continue - } - typeName := "" - // Bare identifiers (int, string, float) arrive via Val.BareStr. - if entry.Value.BareStr != nil { - typeName = *entry.Value.BareStr - } else if entry.Value.Str != nil { - typeName = unquote(*entry.Value.Str) - } - decl.Fields = append(decl.Fields, SessionStateField{ - Name: entry.Key, - TypeName: typeName, - }) - } - return decl -} - func rawToSignal(r *rawSignalDecl) *SignalDecl { return &SignalDecl{ SignalType: r.SignalType, diff --git a/src/semantic-router/pkg/dsl/routing_contract.go b/src/semantic-router/pkg/dsl/routing_contract.go index 44dabe6114..223beeec70 100644 --- a/src/semantic-router/pkg/dsl/routing_contract.go +++ b/src/semantic-router/pkg/dsl/routing_contract.go @@ -40,11 +40,6 @@ func DecompileRouting(cfg *config.RouterConfig) (string, error) { d.pluginTemplates = make(map[string]*pluginTemplate) d.extractPluginTemplates() - if len(cfg.SessionStates) > 0 { - d.writeSection("SESSION_STATES") - d.decompileSessionStates() - } - d.writeSection("SIGNALS") d.decompileSignals() @@ -69,26 +64,12 @@ func DecompileRouting(cfg *config.RouterConfig) (string, error) { func DecompileRoutingToAST(cfg *config.RouterConfig) *Program { d := &decompiler{cfg: cfg} prog := &Program{} - d.appendSessionStatesToProgram(prog) d.appendSignalsToProgram(prog) d.appendModelsToProgram(prog) d.appendRoutesToProgram(prog) return prog } -func (d *decompiler) appendSessionStatesToProgram(prog *Program) { - for _, ss := range d.cfg.SessionStates { - decl := &SessionStateDecl{Name: ss.Name} - for _, f := range ss.Fields { - decl.Fields = append(decl.Fields, SessionStateField{ - Name: f.Name, - TypeName: f.TypeName, - }) - } - prog.SessionStates = append(prog.SessionStates, decl) - } -} - func (d *decompiler) appendSignalsToProgram(prog *Program) { d.appendCoreSignals(prog) d.appendOperationalSignals(prog) @@ -198,6 +179,9 @@ func (d *decompiler) appendOperationalSignals(prog *Program) { for _, mod := range d.cfg.ModalityRules { prog.Signals = append(prog.Signals, d.modalityToSignal(&mod)) } + for _, session := range d.cfg.SessionRules { + prog.Signals = append(prog.Signals, d.sessionToSignal(&session)) + } for _, rb := range d.cfg.RoleBindings { prog.Signals = append(prog.Signals, d.roleBindingToSignal(&rb)) } diff --git a/src/semantic-router/pkg/dsl/validator_conflicts.go b/src/semantic-router/pkg/dsl/validator_conflicts.go index 4dbd7bfc63..c24ad25df6 100644 --- a/src/semantic-router/pkg/dsl/validator_conflicts.go +++ b/src/semantic-router/pkg/dsl/validator_conflicts.go @@ -17,47 +17,6 @@ func (v *Validator) checkConflicts() { v.checkProjections() v.checkTestBlocks() v.checkTierConstraints() - v.checkSessionStates() -} - -// checkSessionStates validates SESSION_STATE declarations for duplicate names, -// invalid field types, duplicate field names, and empty names. -func (v *Validator) checkSessionStates() { - seen := make(map[string]bool) - validTypes := map[string]bool{"int": true, "string": true, "float": true} - - for _, ss := range v.prog.SessionStates { - if ss.Name == "" { - v.addDiag(DiagConstraint, ss.Pos, "SESSION_STATE: name cannot be empty", nil) - continue - } - if seen[ss.Name] { - v.addDiag(DiagConstraint, ss.Pos, - fmt.Sprintf("SESSION_STATE %q: duplicate declaration name", ss.Name), nil) - continue - } - seen[ss.Name] = true - - fieldsSeen := make(map[string]bool) - for _, f := range ss.Fields { - if f.Name == "" { - v.addDiag(DiagConstraint, ss.Pos, - fmt.Sprintf("SESSION_STATE %q: field name cannot be empty", ss.Name), nil) - continue - } - if fieldsSeen[f.Name] { - v.addDiag(DiagConstraint, ss.Pos, - fmt.Sprintf("SESSION_STATE %q: duplicate field name %q", ss.Name, f.Name), nil) - continue - } - fieldsSeen[f.Name] = true - if !validTypes[f.TypeName] { - v.addDiag(DiagConstraint, ss.Pos, - fmt.Sprintf("SESSION_STATE %q: field %q has invalid type %q (supported: int, string, float)", - ss.Name, f.Name, f.TypeName), nil) - } - } - } } // checkDomainSignalOverlap detects MMLU category strings shared by two or more diff --git a/src/semantic-router/pkg/extproc/recorder.go b/src/semantic-router/pkg/extproc/recorder.go index 3130b1e5f5..d828853f31 100644 --- a/src/semantic-router/pkg/extproc/recorder.go +++ b/src/semantic-router/pkg/extproc/recorder.go @@ -157,24 +157,28 @@ func buildReplayRoutingRecord( guardrailsEnabled, jailbreakEnabled, piiEnabled, hallucinationEnabled := replayGuardrailState(ctx) decisionTier, decisionPriority := replayDecisionMetadata(ctx) record := routerreplay.RoutingRecord{ - RequestID: ctx.RequestID, - Decision: decisionName, - DecisionTier: decisionTier, - DecisionPriority: decisionPriority, - Category: ctx.VSRSelectedCategory, - OriginalModel: originalModel, - SelectedModel: replaySelectedModel(originalModel, selectedModel), - ReasoningMode: replayReasoningMode(ctx), - ConfidenceScore: ctx.VSRSelectedDecisionConfidence, - SelectionMethod: ctx.VSRSelectionMethod, - Signals: replaySignalState(ctx), - Projections: replayProjectionState(ctx), - ProjectionScores: cloneReplayFloat64Map(ctx.VSRProjectionScores), - SignalConfidences: cloneReplayFloat64Map(ctx.VSRSignalConfidences), - SignalValues: cloneReplayFloat64Map(ctx.VSRSignalValues), - ToolTrace: buildReplayRequestToolTrace(ctx), - Streaming: ctx.ExpectStreamingResponse, - FromCache: ctx.VSRCacheHit, + RequestID: ctx.RequestID, + Decision: decisionName, + DecisionTier: decisionTier, + DecisionPriority: decisionPriority, + Category: ctx.VSRSelectedCategory, + OriginalModel: originalModel, + SelectedModel: replaySelectedModel(originalModel, selectedModel), + ReasoningMode: replayReasoningMode(ctx), + ConfidenceScore: ctx.VSRSelectedDecisionConfidence, + SelectionMethod: ctx.VSRSelectionMethod, + SessionID: ctx.SessionID, + TurnIndex: ctx.TurnIndex, + PreviousModel: ctx.PreviousModel, + CacheWarmthEstimate: ctx.CacheWarmthEstimate, + Signals: replaySignalState(ctx), + Projections: replayProjectionState(ctx), + ProjectionScores: cloneReplayFloat64Map(ctx.VSRProjectionScores), + SignalConfidences: cloneReplayFloat64Map(ctx.VSRSignalConfidences), + SignalValues: cloneReplayFloat64Map(ctx.VSRSignalValues), + ToolTrace: buildReplayRequestToolTrace(ctx), + Streaming: ctx.ExpectStreamingResponse, + FromCache: ctx.VSRCacheHit, GuardrailsEnabled: guardrailsEnabled, JailbreakEnabled: jailbreakEnabled, @@ -239,6 +243,7 @@ func replaySignalState(ctx *RequestContext) routerreplay.Signal { Structure: ctx.VSRMatchedStructure, Complexity: ctx.VSRMatchedComplexity, Modality: ctx.VSRMatchedModality, + Session: ctx.VSRMatchedSession, Authz: ctx.VSRMatchedAuthz, Jailbreak: ctx.VSRMatchedJailbreak, PII: ctx.VSRMatchedPII, diff --git a/src/semantic-router/pkg/extproc/recorder_test.go b/src/semantic-router/pkg/extproc/recorder_test.go index c3ea5f0041..6445fbaecb 100644 --- a/src/semantic-router/pkg/extproc/recorder_test.go +++ b/src/semantic-router/pkg/extproc/recorder_test.go @@ -16,6 +16,10 @@ func TestBuildReplayRoutingRecordCapturesRoutingMetadata(t *testing.T) { VSRSelectionMethod: "router_dc", VSRCacheHit: true, ExpectStreamingResponse: true, + SessionID: "session-123", + TurnIndex: 4, + PreviousModel: "model-a", + CacheWarmthEstimate: 0.84, VSRSelectedDecision: &config.Decision{ Name: "balance", Tier: 3, @@ -23,6 +27,7 @@ func TestBuildReplayRoutingRecordCapturesRoutingMetadata(t *testing.T) { }, VSRMatchedKeywords: []string{"math_keyword"}, VSRMatchedModality: []string{"AR"}, + VSRMatchedSession: []string{"session_present"}, VSRMatchedAuthz: []string{"premium_tier"}, VSRMatchedJailbreak: []string{"jailbreak_detector"}, VSRMatchedPII: []string{"email_block"}, @@ -58,9 +63,24 @@ func TestBuildReplayRoutingRecordCapturesRoutingMetadata(t *testing.T) { if got := record.SignalValues["reask:persistently_dissatisfied"]; got != 2 { t.Fatalf("expected signal value 2, got %v", got) } + if record.SessionID != "session-123" { + t.Fatalf("expected session ID=session-123, got %q", record.SessionID) + } + if record.TurnIndex != 4 { + t.Fatalf("expected turn index=4, got %d", record.TurnIndex) + } + if record.PreviousModel != "model-a" { + t.Fatalf("expected previous model=model-a, got %q", record.PreviousModel) + } + if record.CacheWarmthEstimate != 0.84 { + t.Fatalf("expected cache warmth=0.84, got %v", record.CacheWarmthEstimate) + } if !reflect.DeepEqual(record.Signals.Modality, []string{"AR"}) { t.Fatalf("unexpected modality signals: %#v", record.Signals.Modality) } + if !reflect.DeepEqual(record.Signals.Session, []string{"session_present"}) { + t.Fatalf("unexpected session signals: %#v", record.Signals.Session) + } if !reflect.DeepEqual(record.Signals.Authz, []string{"premium_tier"}) { t.Fatalf("unexpected authz signals: %#v", record.Signals.Authz) } diff --git a/src/semantic-router/pkg/extproc/req_filter_classification.go b/src/semantic-router/pkg/extproc/req_filter_classification.go index 77ab417203..fe018d3541 100644 --- a/src/semantic-router/pkg/extproc/req_filter_classification.go +++ b/src/semantic-router/pkg/extproc/req_filter_classification.go @@ -63,7 +63,7 @@ func (r *OpenAIRouter) performDecisionEvaluation(originalModel string, history s // The algorithm parameter allows per-decision algorithm override (aligned with looper pattern). // The categoryName parameter is the detected domain category (e.g., "physics", "math") for ML feature vectors. // Returns the selected model and the method name used for logging. -func (r *OpenAIRouter) selectModelFromCandidates(modelRefs []config.ModelRef, decisionName string, query string, algorithm *config.AlgorithmConfig, categoryName string) (*config.ModelRef, string) { +func (r *OpenAIRouter) selectModelFromCandidates(ctx *RequestContext, modelRefs []config.ModelRef, decisionName string, query string, originalModel string, algorithm *config.AlgorithmConfig, categoryName string) (*config.ModelRef, string) { if len(modelRefs) == 0 { return nil, "" } @@ -88,20 +88,8 @@ func (r *OpenAIRouter) selectModelFromCandidates(modelRefs []config.ModelRef, de return &modelRefs[0], string(method) } - // Build selection context with cost/quality weights - costWeight, qualityWeight := r.getSelectionWeights(algorithm) - latencyAwareTPOTPercentile, latencyAwareTTFTPercentile := r.getLatencyAwarePercentiles(algorithm) - - selCtx := &selection.SelectionContext{ - Query: query, - DecisionName: decisionName, - CategoryName: categoryName, - CandidateModels: modelRefs, - CostWeight: costWeight, - QualityWeight: qualityWeight, - LatencyAwareTPOTPercentile: latencyAwareTPOTPercentile, - LatencyAwareTTFTPercentile: latencyAwareTTFTPercentile, - } + // Build selection context with request-time session facts and lookup-table priors. + selCtx := r.buildSelectionContext(ctx, modelRefs, decisionName, query, originalModel, algorithm, categoryName) // Perform selection result, err := selector.Select(context.Background(), selCtx) diff --git a/src/semantic-router/pkg/extproc/req_filter_classification_runtime.go b/src/semantic-router/pkg/extproc/req_filter_classification_runtime.go index 49a357805f..24df06ef89 100644 --- a/src/semantic-router/pkg/extproc/req_filter_classification_runtime.go +++ b/src/semantic-router/pkg/extproc/req_filter_classification_runtime.go @@ -19,6 +19,7 @@ var selectionMethodByAlgorithmType = map[string]selection.SelectionMethod{ "router_dc": selection.MethodRouterDC, "automix": selection.MethodAutoMix, "hybrid": selection.MethodHybrid, + "session_aware": selection.MethodSessionAware, "rl_driven": selection.MethodRLDriven, "gmtrouter": selection.MethodGMTRouter, "latency_aware": selection.MethodLatencyAware, @@ -61,6 +62,7 @@ func (r *OpenAIRouter) evaluateSignalsForDecision( } signalLatency := time.Since(signalStart).Milliseconds() + r.applyRuntimeSessionSignals(ctx, signals) r.applySignalResultsToContext(ctx, signals) logSignalEvaluationResults(ctx, signalLatency, signals) tracing.EndSignalSpan(signalSpan, collectMatchedSignalRules(signals), 1.0, signalLatency) @@ -85,6 +87,7 @@ func logSignalEvaluationResults(ctx *RequestContext, signalLatencyMs int64, sign "structure": signals.MatchedStructureRules, "complexity": signals.MatchedComplexityRules, "modality": signals.MatchedModalityRules, + "session": signals.MatchedSessionRules, "authz": signals.MatchedAuthzRules, "jailbreak": signals.MatchedJailbreakRules, "pii": signals.MatchedPIIRules, @@ -216,9 +219,11 @@ func (r *OpenAIRouter) selectDecisionRuntimeModel( } selectedModelRef, usedMethod := r.selectModelFromCandidates( + ctx, result.Decision.ModelRefs, decisionName, userContent, + ctx.RequestModel, result.Decision.Algorithm, categoryName, ) diff --git a/src/semantic-router/pkg/extproc/req_filter_classification_session.go b/src/semantic-router/pkg/extproc/req_filter_classification_session.go new file mode 100644 index 0000000000..6adadd01f9 --- /dev/null +++ b/src/semantic-router/pkg/extproc/req_filter_classification_session.go @@ -0,0 +1,98 @@ +package extproc + +import ( + "strings" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/classification" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +func (r *OpenAIRouter) applyRuntimeSessionSignals(ctx *RequestContext, signals *classification.SignalResults) { + if r == nil || ctx == nil || signals == nil || len(r.Config.SessionRules) == 0 { + return + } + if signals.SignalValues == nil { + signals.SignalValues = make(map[string]float64) + } + if signals.SignalConfidences == nil { + signals.SignalConfidences = make(map[string]float64) + } + + taskFamily := "" + if len(signals.MatchedDomainRules) > 0 { + taskFamily = strings.TrimSpace(signals.MatchedDomainRules[0]) + } + + for _, rule := range r.Config.SessionRules { + value := r.resolveSessionRuleValue(rule, ctx, taskFamily) + key := config.SignalTypeSession + ":" + rule.Name + signals.SignalValues[key] = value + signals.SignalConfidences[key] = 1.0 + if numericPredicateMatches(rule.Predicate, value) { + signals.MatchedSessionRules = append(signals.MatchedSessionRules, rule.Name) + } + } +} + +func (r *OpenAIRouter) resolveSessionRuleValue(rule config.SessionRule, ctx *RequestContext, taskFamily string) float64 { + if ctx == nil { + return 0 + } + currentModel := strings.TrimSpace(ctx.PreviousModel) + candidateModel := strings.TrimSpace(rule.CandidateModel) + resolvedTaskFamily := strings.TrimSpace(rule.IntentOrDomain) + if resolvedTaskFamily == "" { + resolvedTaskFamily = strings.TrimSpace(taskFamily) + } + if currentModel == "" { + currentModel = strings.TrimSpace(rule.PreviousModel) + } + + switch config.NormalizeSessionFact(rule.Fact) { + case config.SessionFactSessionPresent: + return boolToFloat(ctx.SessionID != "") + case config.SessionFactHasPreviousModel: + if strings.TrimSpace(rule.PreviousModel) != "" { + return boolToFloat(strings.EqualFold(strings.TrimSpace(ctx.PreviousModel), strings.TrimSpace(rule.PreviousModel))) + } + return boolToFloat(strings.TrimSpace(ctx.PreviousModel) != "") + case config.SessionFactTurnIndex: + return float64(ctx.TurnIndex) + case config.SessionFactCacheWarmth: + return ctx.CacheWarmthEstimate + case config.SessionFactRemainingTurns: + return r.lookupRemainingTurns(resolvedTaskFamily, ctx.TurnIndex) + case config.SessionFactHandoffPenalty: + return r.lookupHandoffPenalty(currentModel, candidateModel) + case config.SessionFactQualityGap: + return r.lookupQualityGap(resolvedTaskFamily, currentModel, candidateModel) + default: + return 0 + } +} + +func numericPredicateMatches(predicate *config.NumericPredicate, value float64) bool { + if predicate == nil { + return false + } + if predicate.GT != nil && !(value > *predicate.GT) { + return false + } + if predicate.GTE != nil && !(value >= *predicate.GTE) { + return false + } + if predicate.LT != nil && !(value < *predicate.LT) { + return false + } + if predicate.LTE != nil && !(value <= *predicate.LTE) { + return false + } + return true +} + +func boolToFloat(value bool) float64 { + if value { + return 1 + } + return 0 +} diff --git a/src/semantic-router/pkg/extproc/req_filter_classification_signal.go b/src/semantic-router/pkg/extproc/req_filter_classification_signal.go index 7e4f1431c8..3e9d9837ff 100644 --- a/src/semantic-router/pkg/extproc/req_filter_classification_signal.go +++ b/src/semantic-router/pkg/extproc/req_filter_classification_signal.go @@ -89,6 +89,7 @@ func (r *OpenAIRouter) applySignalResultsToContext(ctx *RequestContext, signals ctx.VSRMatchedStructure = signals.MatchedStructureRules ctx.VSRMatchedComplexity = signals.MatchedComplexityRules ctx.VSRMatchedModality = signals.MatchedModalityRules + ctx.VSRMatchedSession = signals.MatchedSessionRules ctx.VSRMatchedAuthz = signals.MatchedAuthzRules ctx.VSRMatchedJailbreak = signals.MatchedJailbreakRules ctx.VSRMatchedPII = signals.MatchedPIIRules @@ -137,6 +138,7 @@ func collectMatchedSignalRules(signals *classification.SignalResults) []string { allMatchedRules = append(allMatchedRules, signals.MatchedStructureRules...) allMatchedRules = append(allMatchedRules, signals.MatchedComplexityRules...) allMatchedRules = append(allMatchedRules, signals.MatchedModalityRules...) + allMatchedRules = append(allMatchedRules, signals.MatchedSessionRules...) allMatchedRules = append(allMatchedRules, signals.MatchedAuthzRules...) allMatchedRules = append(allMatchedRules, signals.MatchedJailbreakRules...) allMatchedRules = append(allMatchedRules, signals.MatchedPIIRules...) diff --git a/src/semantic-router/pkg/extproc/request_context.go b/src/semantic-router/pkg/extproc/request_context.go index 84dbdcf443..246798a3b7 100644 --- a/src/semantic-router/pkg/extproc/request_context.go +++ b/src/semantic-router/pkg/extproc/request_context.go @@ -108,6 +108,7 @@ type RequestContext struct { VSRMatchedStructure []string // Matched structure rule names VSRMatchedComplexity []string // Matched complexity rules with difficulty level (e.g. "code_complexity:hard") VSRMatchedModality []string // Matched modality signals: "AR", "DIFFUSION", or "BOTH" + VSRMatchedSession []string // Matched runtime-derived session signal names VSRMatchedAuthz []string // Matched authz rule names for user-level routing VSRMatchedJailbreak []string // Matched jailbreak rule names (confidence >= threshold) VSRMatchedPII []string // Matched PII rule names (denied PII types detected) diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go index e0756a28b9..92f59a5092 100644 --- a/src/semantic-router/pkg/extproc/router.go +++ b/src/semantic-router/pkg/extproc/router.go @@ -17,6 +17,7 @@ import ( "github.com/vllm-project/semantic-router/src/semantic-router/pkg/routerreplay" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/routerruntime" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection/lookuptable" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/services" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/tools" ) @@ -34,7 +35,8 @@ type OpenAIRouter struct { ReplayStoreShared bool // ModelSelector is the registry of advanced model selection algorithms // initialized from config.IntelligentRouting.ModelSelection. - ModelSelector *selection.Registry + ModelSelector *selection.Registry + LookupTable lookuptable.LookupTable ReplayRecorders map[string]*routerreplay.Recorder MemoryStore memory.Store MemoryExtractor *memory.MemoryExtractor diff --git a/src/semantic-router/pkg/extproc/router_build.go b/src/semantic-router/pkg/extproc/router_build.go index e5c07064e1..c63abf1a58 100644 --- a/src/semantic-router/pkg/extproc/router_build.go +++ b/src/semantic-router/pkg/extproc/router_build.go @@ -13,6 +13,7 @@ import ( "github.com/vllm-project/semantic-router/src/semantic-router/pkg/routerreplay" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/routerreplay/store" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection/lookuptable" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/services" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/tools" ) @@ -35,6 +36,7 @@ type routerComponents struct { replayStoreShared bool replayRecorders map[string]*routerreplay.Recorder modelSelector *selection.Registry + lookupTable lookuptable.LookupTable memoryStore memory.Store memoryExtractor *memory.MemoryExtractor credentialResolver *authz.CredentialResolver @@ -129,7 +131,7 @@ func buildRouterComponents(cfg *config.RouterConfig) (*routerComponents, error) if replayRecorder != nil { replayReaderForLookup = replayRecorder.Reader() } - modelSelector, _, lookupTableCancel := createModelSelectorRegistry(cfg, replayReaderForLookup) + modelSelector, lookupTable, lookupTableCancel := createModelSelectorRegistry(cfg, replayReaderForLookup) memoryStore, memoryExtractor := createMemoryRuntime(cfg) credentialResolver := buildCredentialResolver(cfg) rateLimiter := buildRateLimitResolver(cfg) @@ -157,6 +159,7 @@ func buildRouterComponents(cfg *config.RouterConfig) (*routerComponents, error) replayStoreShared: replayStoreShared, replayRecorders: replayRecorders, modelSelector: modelSelector, + lookupTable: lookupTable, memoryStore: memoryStore, memoryExtractor: memoryExtractor, credentialResolver: credentialResolver, @@ -177,6 +180,7 @@ func (components *routerComponents) buildRouter() *OpenAIRouter { ReplayRecorder: components.replayRecorder, ReplayStoreShared: components.replayStoreShared, ModelSelector: components.modelSelector, + LookupTable: components.lookupTable, ReplayRecorders: components.replayRecorders, MemoryStore: components.memoryStore, MemoryExtractor: components.memoryExtractor, diff --git a/src/semantic-router/pkg/extproc/router_selection.go b/src/semantic-router/pkg/extproc/router_selection.go index 6cef6ceade..1586ee3bdc 100644 --- a/src/semantic-router/pkg/extproc/router_selection.go +++ b/src/semantic-router/pkg/extproc/router_selection.go @@ -74,9 +74,10 @@ func buildModelSelectionConfig(cfg *config.RouterConfig) *selection.ModelSelecti Method: "static", } - eloFromDecision, routerDCFromDecision := findDecisionScopedSelectionConfigs(cfg) + eloFromDecision, routerDCFromDecision, sessionAwareFromDecision := findDecisionScopedSelectionConfigs(cfg) modelSelectionCfg.Elo = buildEloSelectionConfig(cfg, eloFromDecision) modelSelectionCfg.RouterDC = buildRouterDCSelectionConfig(cfg, routerDCFromDecision) + modelSelectionCfg.SessionAware = buildSessionAwareSelectionConfig(cfg, sessionAwareFromDecision) modelSelectionCfg.AutoMix = buildAutoMixSelectionConfig(cfg) modelSelectionCfg.Hybrid = buildHybridSelectionConfig(cfg) modelSelectionCfg.ML = buildMLSelectionConfig(cfg) @@ -85,10 +86,11 @@ func buildModelSelectionConfig(cfg *config.RouterConfig) *selection.ModelSelecti func findDecisionScopedSelectionConfigs( cfg *config.RouterConfig, -) (*config.EloSelectionConfig, *config.RouterDCSelectionConfig) { +) (*config.EloSelectionConfig, *config.RouterDCSelectionConfig, *config.SessionAwareSelectionConfig) { intelligentRouting := cfg.IntelligentRouting var eloFromDecision *config.EloSelectionConfig var routerDCFromDecision *config.RouterDCSelectionConfig + var sessionAwareFromDecision *config.SessionAwareSelectionConfig for _, decision := range intelligentRouting.Decisions { if decision.Algorithm == nil { @@ -104,9 +106,14 @@ func findDecisionScopedSelectionConfigs( routerDCFromDecision == nil { routerDCFromDecision = decision.Algorithm.RouterDC } + if decision.Algorithm.Type == "session_aware" && + decision.Algorithm.SessionAware != nil && + sessionAwareFromDecision == nil { + sessionAwareFromDecision = decision.Algorithm.SessionAware + } } - return eloFromDecision, routerDCFromDecision + return eloFromDecision, routerDCFromDecision, sessionAwareFromDecision } func buildEloSelectionConfig( @@ -200,6 +207,44 @@ func buildHybridSelectionConfig(cfg *config.RouterConfig) *selection.HybridConfi } } +func buildSessionAwareSelectionConfig( + cfg *config.RouterConfig, + decisionCfg *config.SessionAwareSelectionConfig, +) *selection.SessionAwareConfig { + intelligentRouting := cfg.IntelligentRouting + sessionCfg := intelligentRouting.ModelSelection.SessionAware + result := &selection.SessionAwareConfig{ + FallbackMethod: sessionCfg.FallbackMethod, + MinTurnsBeforeSwitch: sessionCfg.MinTurnsBeforeSwitch, + StayBias: sessionCfg.StayBias, + QualityGapMultiplier: sessionCfg.QualityGapMultiplier, + HandoffPenaltyWeight: sessionCfg.HandoffPenaltyWeight, + RemainingTurnWeight: sessionCfg.RemainingTurnWeight, + } + if decisionCfg == nil { + return result + } + if decisionCfg.FallbackMethod != "" { + result.FallbackMethod = decisionCfg.FallbackMethod + } + if decisionCfg.MinTurnsBeforeSwitch != 0 { + result.MinTurnsBeforeSwitch = decisionCfg.MinTurnsBeforeSwitch + } + if decisionCfg.StayBias != 0 { + result.StayBias = decisionCfg.StayBias + } + if decisionCfg.QualityGapMultiplier != 0 { + result.QualityGapMultiplier = decisionCfg.QualityGapMultiplier + } + if decisionCfg.HandoffPenaltyWeight != 0 { + result.HandoffPenaltyWeight = decisionCfg.HandoffPenaltyWeight + } + if decisionCfg.RemainingTurnWeight != 0 { + result.RemainingTurnWeight = decisionCfg.RemainingTurnWeight + } + return result +} + func buildMLSelectionConfig(cfg *config.RouterConfig) *selection.MLSelectorConfig { intelligentRouting := cfg.IntelligentRouting mlCfg := intelligentRouting.ModelSelection.ML diff --git a/src/semantic-router/pkg/extproc/router_selection_context.go b/src/semantic-router/pkg/extproc/router_selection_context.go new file mode 100644 index 0000000000..f2837da569 --- /dev/null +++ b/src/semantic-router/pkg/extproc/router_selection_context.go @@ -0,0 +1,111 @@ +package extproc + +import ( + "strings" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" +) + +func (r *OpenAIRouter) buildSelectionContext( + ctx *RequestContext, + modelRefs []config.ModelRef, + decisionName string, + query string, + originalModel string, + algorithm *config.AlgorithmConfig, + categoryName string, +) *selection.SelectionContext { + costWeight, qualityWeight := r.getSelectionWeights(algorithm) + latencyAwareTPOTPercentile, latencyAwareTTFTPercentile := r.getLatencyAwarePercentiles(algorithm) + taskFamily := resolveSelectionTaskFamily(categoryName, decisionName) + currentModel := strings.TrimSpace(ctx.PreviousModel) + qualityGapByCandidate, handoffPenaltyByCandidate := r.lookupCandidateSessionFacts(taskFamily, currentModel, modelRefs) + + return &selection.SelectionContext{ + Query: query, + DecisionName: decisionName, + CategoryName: categoryName, + CandidateModels: modelRefs, + CostWeight: costWeight, + QualityWeight: qualityWeight, + UserID: extractUserID(ctx), + SessionID: ctx.SessionID, + TurnIndex: ctx.TurnIndex, + PreviousModel: ctx.PreviousModel, + CurrentModel: currentModel, + OriginalModel: originalModel, + CacheWarmthEstimate: ctx.CacheWarmthEstimate, + RemainingTurnsEstimate: r.lookupRemainingTurns(taskFamily, ctx.TurnIndex), + QualityGapByCandidate: qualityGapByCandidate, + HandoffPenaltyByCandidate: handoffPenaltyByCandidate, + LatencyAwareTPOTPercentile: latencyAwareTPOTPercentile, + LatencyAwareTTFTPercentile: latencyAwareTTFTPercentile, + } +} + +func resolveSelectionTaskFamily(categoryName string, decisionName string) string { + if trimmed := strings.TrimSpace(categoryName); trimmed != "" { + return trimmed + } + return strings.TrimSpace(decisionName) +} + +func (r *OpenAIRouter) lookupCandidateSessionFacts( + taskFamily string, + currentModel string, + modelRefs []config.ModelRef, +) (map[string]float64, map[string]float64) { + qualityGapByCandidate := make(map[string]float64) + handoffPenaltyByCandidate := make(map[string]float64) + if r == nil || r.LookupTable == nil || strings.TrimSpace(currentModel) == "" { + return qualityGapByCandidate, handoffPenaltyByCandidate + } + + for _, ref := range modelRefs { + candidateModel := strings.TrimSpace(ref.Model) + if candidateModel == "" || candidateModel == strings.TrimSpace(currentModel) { + continue + } + qualityGapByCandidate[candidateModel] = r.lookupQualityGap(taskFamily, currentModel, candidateModel) + handoffPenaltyByCandidate[candidateModel] = r.lookupHandoffPenalty(currentModel, candidateModel) + } + return qualityGapByCandidate, handoffPenaltyByCandidate +} + +func (r *OpenAIRouter) lookupRemainingTurns(taskFamily string, turnIndex int) float64 { + if r == nil || r.LookupTable == nil || strings.TrimSpace(taskFamily) == "" { + return 0 + } + prior, ok := r.LookupTable.RemainingTurnPrior(strings.TrimSpace(taskFamily)) + if !ok { + return 0 + } + remaining := prior - float64(turnIndex) + if remaining < 0 { + return 0 + } + return remaining +} + +func (r *OpenAIRouter) lookupQualityGap(taskFamily string, currentModel string, candidateModel string) float64 { + if r == nil || r.LookupTable == nil || strings.TrimSpace(taskFamily) == "" || strings.TrimSpace(currentModel) == "" || strings.TrimSpace(candidateModel) == "" { + return 0 + } + value, ok := r.LookupTable.QualityGap(strings.TrimSpace(taskFamily), strings.TrimSpace(currentModel), strings.TrimSpace(candidateModel)) + if !ok { + return 0 + } + return value +} + +func (r *OpenAIRouter) lookupHandoffPenalty(currentModel string, candidateModel string) float64 { + if r == nil || r.LookupTable == nil || strings.TrimSpace(currentModel) == "" || strings.TrimSpace(candidateModel) == "" { + return 0 + } + value, ok := r.LookupTable.HandoffPenalty(strings.TrimSpace(currentModel), strings.TrimSpace(candidateModel)) + if !ok { + return 0 + } + return value +} diff --git a/src/semantic-router/pkg/k8s/testdata/output/01-basic.yaml b/src/semantic-router/pkg/k8s/testdata/output/01-basic.yaml index 380f2c74a5..a5f97a1a2a 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/01-basic.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/01-basic.yaml @@ -158,6 +158,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/02-keyword-only.yaml b/src/semantic-router/pkg/k8s/testdata/output/02-keyword-only.yaml index 14b816f7c5..e35edb0018 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/02-keyword-only.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/02-keyword-only.yaml @@ -141,6 +141,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/03-embedding-only.yaml b/src/semantic-router/pkg/k8s/testdata/output/03-embedding-only.yaml index ea74bdff45..3b1340ff23 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/03-embedding-only.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/03-embedding-only.yaml @@ -142,6 +142,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/04-domain-only.yaml b/src/semantic-router/pkg/k8s/testdata/output/04-domain-only.yaml index fe6fd40b5a..5dbc62beba 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/04-domain-only.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/04-domain-only.yaml @@ -147,6 +147,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/05-keyword-embedding.yaml b/src/semantic-router/pkg/k8s/testdata/output/05-keyword-embedding.yaml index 52ef1bc15b..ecb222b8a5 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/05-keyword-embedding.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/05-keyword-embedding.yaml @@ -147,6 +147,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/06-keyword-domain.yaml b/src/semantic-router/pkg/k8s/testdata/output/06-keyword-domain.yaml index ad15b51b55..cd8ad0e649 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/06-keyword-domain.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/06-keyword-domain.yaml @@ -153,6 +153,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/07-domain-embedding.yaml b/src/semantic-router/pkg/k8s/testdata/output/07-domain-embedding.yaml index 243c7f0b79..854988131d 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/07-domain-embedding.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/07-domain-embedding.yaml @@ -147,6 +147,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/08-keyword-embedding-domain.yaml b/src/semantic-router/pkg/k8s/testdata/output/08-keyword-embedding-domain.yaml index 7ba8f715d9..75534c4e19 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/08-keyword-embedding-domain.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/08-keyword-embedding-domain.yaml @@ -159,6 +159,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/09-keyword-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/09-keyword-plugin.yaml index 22238f2e21..b8ee0b3890 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/09-keyword-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/09-keyword-plugin.yaml @@ -122,6 +122,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/10-embedding-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/10-embedding-plugin.yaml index e34b0a93ee..c397868f0f 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/10-embedding-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/10-embedding-plugin.yaml @@ -123,6 +123,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/11-domain-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/11-domain-plugin.yaml index d93efff78e..c1f3d0f624 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/11-domain-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/11-domain-plugin.yaml @@ -123,6 +123,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/12-keyword-embedding-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/12-keyword-embedding-plugin.yaml index 3941bc86e0..577ffcd0e7 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/12-keyword-embedding-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/12-keyword-embedding-plugin.yaml @@ -135,6 +135,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/13-keyword-domain-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/13-keyword-domain-plugin.yaml index 1ea924d3f1..71c6035f36 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/13-keyword-domain-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/13-keyword-domain-plugin.yaml @@ -136,6 +136,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/14-domain-embedding-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/14-domain-embedding-plugin.yaml index 56917a18d1..70e5bd440a 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/14-domain-embedding-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/14-domain-embedding-plugin.yaml @@ -133,6 +133,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/15-keyword-embedding-domain-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/15-keyword-embedding-domain-plugin.yaml index 8d2a1641f2..df638882c0 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/15-keyword-embedding-domain-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/15-keyword-embedding-domain-plugin.yaml @@ -201,6 +201,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/k8s/testdata/output/16-keyword-embedding-domain-no-plugin.yaml b/src/semantic-router/pkg/k8s/testdata/output/16-keyword-embedding-domain-no-plugin.yaml index d41ad4d9ca..7f9e406922 100644 --- a/src/semantic-router/pkg/k8s/testdata/output/16-keyword-embedding-domain-no-plugin.yaml +++ b/src/semantic-router/pkg/k8s/testdata/output/16-keyword-embedding-domain-no-plugin.yaml @@ -175,6 +175,8 @@ global: authz: {} ratelimit: {} router_replay: {} + startup_status: + store_backend: "" stores: semantic_cache: backend_type: memory diff --git a/src/semantic-router/pkg/routerreplay/recorder.go b/src/semantic-router/pkg/routerreplay/recorder.go index eced717200..9fb61407b6 100644 --- a/src/semantic-router/pkg/routerreplay/recorder.go +++ b/src/semantic-router/pkg/routerreplay/recorder.go @@ -173,6 +173,7 @@ func logSignalFields(signals Signal) map[string]interface{} { "structure": signals.Structure, "complexity": signals.Complexity, "modality": signals.Modality, + "session": signals.Session, "authz": signals.Authz, "jailbreak": signals.Jailbreak, "pii": signals.PII, @@ -259,23 +260,27 @@ func appendUsageCostLogFields(fields map[string]interface{}, r RoutingRecord) { func LogFields(r RoutingRecord, event string) map[string]interface{} { fields := map[string]interface{}{ - "event": event, - "replay_id": r.ID, - "decision": r.Decision, - "decision_tier": r.DecisionTier, - "decision_priority": r.DecisionPriority, - "category": r.Category, - "original_model": r.OriginalModel, - "selected_model": r.SelectedModel, - "reasoning_mode": r.ReasoningMode, - "confidence_score": r.ConfidenceScore, - "selection_method": r.SelectionMethod, - "request_id": r.RequestID, - "timestamp": r.Timestamp, - "from_cache": r.FromCache, - "streaming": r.Streaming, - "response_status": r.ResponseStatus, - "signals": logSignalFields(r.Signals), + "event": event, + "replay_id": r.ID, + "decision": r.Decision, + "decision_tier": r.DecisionTier, + "decision_priority": r.DecisionPriority, + "category": r.Category, + "original_model": r.OriginalModel, + "selected_model": r.SelectedModel, + "reasoning_mode": r.ReasoningMode, + "confidence_score": r.ConfidenceScore, + "selection_method": r.SelectionMethod, + "session_id": r.SessionID, + "turn_index": r.TurnIndex, + "previous_model": r.PreviousModel, + "cache_warmth_estimate": r.CacheWarmthEstimate, + "request_id": r.RequestID, + "timestamp": r.Timestamp, + "from_cache": r.FromCache, + "streaming": r.Streaming, + "response_status": r.ResponseStatus, + "signals": logSignalFields(r.Signals), } if len(r.Projections) > 0 { fields["projections"] = r.Projections diff --git a/src/semantic-router/pkg/routerreplay/recorder_test.go b/src/semantic-router/pkg/routerreplay/recorder_test.go index 0ea254ba31..b447cbddfd 100644 --- a/src/semantic-router/pkg/routerreplay/recorder_test.go +++ b/src/semantic-router/pkg/routerreplay/recorder_test.go @@ -142,6 +142,10 @@ func TestLogFieldsIncludesOptionalReplayMetadata(t *testing.T) { assertFieldValue(t, fields, "decision_tier", 2) assertFieldValue(t, fields, "decision_priority", 100) assertFieldValue(t, fields, "selection_method", "router_dc") + assertFieldValue(t, fields, "session_id", "session-123") + assertFieldValue(t, fields, "turn_index", 4) + assertFieldValue(t, fields, "previous_model", "model-a") + assertFieldValue(t, fields, "cache_warmth_estimate", 0.84) assertFieldValue(t, fields, "guardrails_enabled", true) assertFieldValue(t, fields, "jailbreak_type", "prompt_injection") assertFieldValue(t, fields, "pii_entities", []string{"email"}) @@ -168,25 +172,29 @@ func richReplayRoutingRecord( baselineModel *string, ) RoutingRecord { return RoutingRecord{ - ID: "replay-1", - Decision: "decision-a", - DecisionTier: 2, - DecisionPriority: 100, - Category: "math", - OriginalModel: "model-a", - SelectedModel: "model-b", - ReasoningMode: "cot", - ConfidenceScore: 0.91, - SelectionMethod: "router_dc", - RequestID: "req-1", - Timestamp: timestamp, - FromCache: true, - Streaming: true, - ResponseStatus: 200, - Projections: []string{"balance_reasoning"}, - ProjectionScores: map[string]float64{"reasoning_pressure": 0.73}, - SignalConfidences: map[string]float64{"projection:balance_reasoning": 0.73}, - SignalValues: map[string]float64{"reask:likely_dissatisfied": 2}, + ID: "replay-1", + Decision: "decision-a", + DecisionTier: 2, + DecisionPriority: 100, + Category: "math", + OriginalModel: "model-a", + SelectedModel: "model-b", + ReasoningMode: "cot", + ConfidenceScore: 0.91, + SelectionMethod: "router_dc", + SessionID: "session-123", + TurnIndex: 4, + PreviousModel: "model-a", + CacheWarmthEstimate: 0.84, + RequestID: "req-1", + Timestamp: timestamp, + FromCache: true, + Streaming: true, + ResponseStatus: 200, + Projections: []string{"balance_reasoning"}, + ProjectionScores: map[string]float64{"reasoning_pressure": 0.73}, + SignalConfidences: map[string]float64{"projection:balance_reasoning": 0.73}, + SignalValues: map[string]float64{"reask:likely_dissatisfied": 2}, ToolTrace: &ToolTrace{ Flow: "User Query -> LLM Tool Call -> Client Tool Result -> LLM Final Response", Stage: "LLM Final Response", @@ -203,6 +211,7 @@ func richReplayRoutingRecord( Reask: []string{"likely_dissatisfied"}, Complexity: []string{"complex"}, Modality: []string{"AR"}, + Session: []string{"session_present"}, Authz: []string{"premium_tier"}, Jailbreak: []string{"prompt_attack"}, PII: []string{"email"}, @@ -249,6 +258,7 @@ func assertSignalLogFields(t *testing.T, fields map[string]interface{}) { assertFieldValue(t, signals, "reask", []string{"likely_dissatisfied"}) assertFieldValue(t, signals, "complexity", []string{"complex"}) assertFieldValue(t, signals, "modality", []string{"AR"}) + assertFieldValue(t, signals, "session", []string{"session_present"}) assertFieldValue(t, signals, "authz", []string{"premium_tier"}) assertFieldValue(t, signals, "jailbreak", []string{"prompt_attack"}) assertFieldValue(t, signals, "pii", []string{"email"}) diff --git a/src/semantic-router/pkg/routerreplay/store/postgres.go b/src/semantic-router/pkg/routerreplay/store/postgres.go index 1d5278b616..196e0560a9 100644 --- a/src/semantic-router/pkg/routerreplay/store/postgres.go +++ b/src/semantic-router/pkg/routerreplay/store/postgres.go @@ -78,6 +78,12 @@ func (p *PostgresStore) createTable(ctx context.Context) error { original_model VARCHAR(255), selected_model VARCHAR(255), reasoning_mode VARCHAR(255), + confidence_score DOUBLE PRECISION DEFAULT 0, + selection_method VARCHAR(255), + session_id VARCHAR(255), + turn_index INTEGER DEFAULT 0, + previous_model VARCHAR(255), + cache_warmth_estimate DOUBLE PRECISION DEFAULT 0, signals JSONB, projections JSONB, projection_scores JSONB, @@ -114,6 +120,12 @@ func (p *PostgresStore) createTable(ctx context.Context) error { ); ALTER TABLE %s ADD COLUMN IF NOT EXISTS decision_tier INTEGER DEFAULT 0; ALTER TABLE %s ADD COLUMN IF NOT EXISTS decision_priority INTEGER DEFAULT 0; + ALTER TABLE %s ADD COLUMN IF NOT EXISTS confidence_score DOUBLE PRECISION DEFAULT 0; + ALTER TABLE %s ADD COLUMN IF NOT EXISTS selection_method VARCHAR(255); + ALTER TABLE %s ADD COLUMN IF NOT EXISTS session_id VARCHAR(255); + ALTER TABLE %s ADD COLUMN IF NOT EXISTS turn_index INTEGER DEFAULT 0; + ALTER TABLE %s ADD COLUMN IF NOT EXISTS previous_model VARCHAR(255); + ALTER TABLE %s ADD COLUMN IF NOT EXISTS cache_warmth_estimate DOUBLE PRECISION DEFAULT 0; ALTER TABLE %s ADD COLUMN IF NOT EXISTS projections JSONB; ALTER TABLE %s ADD COLUMN IF NOT EXISTS projection_scores JSONB; ALTER TABLE %s ADD COLUMN IF NOT EXISTS signal_confidences JSONB; @@ -130,12 +142,45 @@ func (p *PostgresStore) createTable(ctx context.Context) error { CREATE INDEX IF NOT EXISTS idx_%s_timestamp ON %s (timestamp DESC); CREATE INDEX IF NOT EXISTS idx_%s_created_at ON %s (created_at); CREATE INDEX IF NOT EXISTS idx_%s_request_id ON %s (request_id); + CREATE INDEX IF NOT EXISTS idx_%s_session_timestamp ON %s (session_id, timestamp DESC); CREATE INDEX IF NOT EXISTS idx_%s_decision_timestamp ON %s (decision, timestamp DESC); CREATE INDEX IF NOT EXISTS idx_%s_selected_model_timestamp ON %s (selected_model, timestamp DESC); - `, p.tableName, - p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, - p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, - p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName, p.tableName) + `, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + p.tableName, + ) _, err := p.db.ExecContext(ctx, query) return err @@ -167,7 +212,8 @@ func (p *PostgresStore) Add(ctx context.Context, record Record) (string, error) query := fmt.Sprintf(` INSERT INTO %s ( id, timestamp, request_id, decision, decision_tier, decision_priority, category, - original_model, selected_model, reasoning_mode, + original_model, selected_model, reasoning_mode, confidence_score, selection_method, + session_id, turn_index, previous_model, cache_warmth_estimate, signals, projections, projection_scores, signal_confidences, signal_values, tool_trace, request_body, response_body, response_status, from_cache, streaming, request_body_truncated, response_body_truncated, @@ -176,7 +222,7 @@ func (p *PostgresStore) Add(ctx context.Context, record Record) (string, error) hallucination_enabled, hallucination_detected, hallucination_confidence, hallucination_spans, prompt_tokens, completion_tokens, total_tokens, actual_cost, baseline_cost, cost_savings, currency, baseline_model - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48) `, p.tableName) fn := func() error { diff --git a/src/semantic-router/pkg/routerreplay/store/postgres_record_row.go b/src/semantic-router/pkg/routerreplay/store/postgres_record_row.go index 593e297fa7..dad51ddcea 100644 --- a/src/semantic-router/pkg/routerreplay/store/postgres_record_row.go +++ b/src/semantic-router/pkg/routerreplay/store/postgres_record_row.go @@ -10,7 +10,8 @@ import ( const postgresRecordSelectColumns = ` id, timestamp, request_id, decision, decision_tier, decision_priority, category, - original_model, selected_model, reasoning_mode, + original_model, selected_model, reasoning_mode, confidence_score, selection_method, + session_id, turn_index, previous_model, cache_warmth_estimate, signals, projections, projection_scores, signal_confidences, signal_values, tool_trace, request_body, response_body, response_status, from_cache, streaming, request_body_truncated, response_body_truncated, @@ -128,6 +129,12 @@ func (record postgresInsertRecord) args() []interface{} { record.record.OriginalModel, record.record.SelectedModel, record.record.ReasoningMode, + record.record.ConfidenceScore, + record.record.SelectionMethod, + record.record.SessionID, + record.record.TurnIndex, + record.record.PreviousModel, + record.record.CacheWarmthEstimate, record.signalsJSON, record.projectionsJSON, record.projectionScoresJSON, @@ -202,6 +209,12 @@ func (row *postgresRecordRow) scanDestinations() []interface{} { &row.record.OriginalModel, &row.record.SelectedModel, &row.record.ReasoningMode, + &row.record.ConfidenceScore, + &row.record.SelectionMethod, + &row.record.SessionID, + &row.record.TurnIndex, + &row.record.PreviousModel, + &row.record.CacheWarmthEstimate, &row.signalsJSON, &row.projectionsJSON, &row.projectionScoresJSON, diff --git a/src/semantic-router/pkg/routerreplay/store/store.go b/src/semantic-router/pkg/routerreplay/store/store.go index 5ff275f0c0..cc16cf945e 100644 --- a/src/semantic-router/pkg/routerreplay/store/store.go +++ b/src/semantic-router/pkg/routerreplay/store/store.go @@ -22,6 +22,7 @@ type Signal struct { Structure []string `json:"structure,omitempty"` Complexity []string `json:"complexity,omitempty"` Modality []string `json:"modality,omitempty"` + Session []string `json:"session,omitempty"` Authz []string `json:"authz,omitempty"` Jailbreak []string `json:"jailbreak,omitempty"` PII []string `json:"pii,omitempty"` @@ -73,6 +74,10 @@ type Record struct { ReasoningMode string `json:"reasoning_mode,omitempty"` ConfidenceScore float64 `json:"confidence_score,omitempty"` SelectionMethod string `json:"selection_method,omitempty"` + SessionID string `json:"session_id,omitempty"` + TurnIndex int `json:"turn_index,omitempty"` + PreviousModel string `json:"previous_model,omitempty"` + CacheWarmthEstimate float64 `json:"cache_warmth_estimate,omitempty"` Signals Signal `json:"signals"` Projections []string `json:"projections,omitempty"` ProjectionScores map[string]float64 `json:"projection_scores,omitempty"` @@ -223,6 +228,7 @@ func cloneSignal(signal Signal) Signal { Structure: cloneStringSlice(signal.Structure), Complexity: cloneStringSlice(signal.Complexity), Modality: cloneStringSlice(signal.Modality), + Session: cloneStringSlice(signal.Session), Authz: cloneStringSlice(signal.Authz), Jailbreak: cloneStringSlice(signal.Jailbreak), PII: cloneStringSlice(signal.PII), diff --git a/src/semantic-router/pkg/selection/factory.go b/src/semantic-router/pkg/selection/factory.go index ae11077cc8..9515c01d19 100644 --- a/src/semantic-router/pkg/selection/factory.go +++ b/src/semantic-router/pkg/selection/factory.go @@ -45,6 +45,9 @@ type ModelSelectionConfig struct { // Hybrid configuration (used when method is "hybrid") Hybrid *HybridConfig `yaml:"hybrid,omitempty"` + // SessionAware configuration (used when method is "session_aware") + SessionAware *SessionAwareConfig `yaml:"session_aware,omitempty"` + // ML configuration (used for knn, kmeans, svm methods) ML *MLSelectorConfig `yaml:"ml,omitempty"` @@ -155,6 +158,13 @@ func (f *Factory) Create() Selector { } selector = hybridSelector + case MethodSessionAware: + sessionAwareSelector := NewSessionAwareSelector(f.cfg.SessionAware) + if f.lookupTable != nil { + sessionAwareSelector.SetLookupTable(f.lookupTable) + } + selector = sessionAwareSelector + case MethodGMTRouter: gmtRouterSelector := NewGMTRouterSelector(f.cfg.GMTRouter) if f.modelConfig != nil { @@ -255,6 +265,17 @@ func (f *Factory) CreateAll() *Registry { } registry.Register(MethodHybrid, hybridSelector) + // Create SessionAware selector + sessionAwareCfg := f.cfg.SessionAware + if sessionAwareCfg == nil { + sessionAwareCfg = DefaultSessionAwareConfig() + } + sessionAwareSelector := NewSessionAwareSelector(sessionAwareCfg) + if f.lookupTable != nil { + sessionAwareSelector.SetLookupTable(f.lookupTable) + } + registry.Register(MethodSessionAware, sessionAwareSelector) + // Create ML-based selectors (KNN, KMeans, SVM) mlCfg := f.cfg.ML if mlCfg == nil { diff --git a/src/semantic-router/pkg/selection/lookuptable/builder.go b/src/semantic-router/pkg/selection/lookuptable/builder.go index 1afc5e3bbe..75a1065843 100644 --- a/src/semantic-router/pkg/selection/lookuptable/builder.go +++ b/src/semantic-router/pkg/selection/lookuptable/builder.go @@ -155,27 +155,33 @@ func (b *Builder) deriveQualityGaps(records []store.Record, batch map[Key]Entry, } } -// groupIntoSessions groups records into pseudo-sessions using Decision name and -// a time-window heuristic: records with the same Decision are sorted by -// timestamp, then split into sessions whenever consecutive records are more than -// sessionWindowDuration apart. This is more semantically stable than a RequestID -// prefix because it does not depend on external ID conventions. +// groupIntoSessions groups records into sessions using persisted SessionID when +// available, and falls back to Decision+time-window heuristics for older replay +// records that lack session metadata. func groupIntoSessions(records []store.Record) [][]store.Record { - // Group by Decision first. + bySession := make(map[string][]store.Record) byDecision := make(map[string][]store.Record) for _, r := range records { + if r.SessionID != "" { + bySession[r.SessionID] = append(bySession[r.SessionID], r) + continue + } byDecision[r.Decision] = append(byDecision[r.Decision], r) } var sessions [][]store.Record + for _, recs := range bySession { + sort.Slice(recs, func(i, j int) bool { + return recs[i].Timestamp.Before(recs[j].Timestamp) + }) + sessions = append(sessions, recs) + } + for _, recs := range byDecision { - // Sort chronologically within each Decision group. sort.Slice(recs, func(i, j int) bool { return recs[i].Timestamp.Before(recs[j].Timestamp) }) - // Split into windows: a new session starts when the gap between - // consecutive records exceeds sessionWindowDuration. start := 0 for i := 1; i < len(recs); i++ { if recs[i].Timestamp.Sub(recs[i-1].Timestamp) > sessionWindowDuration { diff --git a/src/semantic-router/pkg/selection/selector.go b/src/semantic-router/pkg/selection/selector.go index 61ac55e208..33fa59684d 100644 --- a/src/semantic-router/pkg/selection/selector.go +++ b/src/semantic-router/pkg/selection/selector.go @@ -52,6 +52,10 @@ const ( // Allows blending Elo, embedding similarity, and cost considerations MethodHybrid SelectionMethod = "hybrid" + // MethodSessionAware uses runtime session facts and replay-derived priors to + // decide whether to stay on the current model or switch mid-session. + MethodSessionAware SelectionMethod = "session_aware" + // MethodStatic uses static scores from configuration (default behavior) MethodStatic SelectionMethod = "static" @@ -172,6 +176,30 @@ type SelectionContext struct { // Used to track within-session model performance SessionID string + // TurnIndex is the zero-based count of prior user turns in the active session. + TurnIndex int + + // PreviousModel is the model used on the immediately preceding turn. + PreviousModel string + + // CurrentModel is the model the session is currently anchored to for stay-vs-switch decisions. + CurrentModel string + + // OriginalModel is the model requested before routing rewrites were applied. + OriginalModel string + + // CacheWarmthEstimate is a replay/runtime-derived [0,1] estimate of cache warmth for the current model. + CacheWarmthEstimate float64 + + // RemainingTurnsEstimate is the replay-derived estimate of how many turns remain after this request. + RemainingTurnsEstimate float64 + + // QualityGapByCandidate stores replay-derived quality-gap estimates keyed by candidate model. + QualityGapByCandidate map[string]float64 + + // HandoffPenaltyByCandidate stores replay-derived switch-penalty estimates keyed by candidate model. + HandoffPenaltyByCandidate map[string]float64 + // LatencyAwareTPOTPercentile is the configured TPOT percentile (1-100) for latency_aware selection LatencyAwareTPOTPercentile int diff --git a/src/semantic-router/pkg/selection/session_aware.go b/src/semantic-router/pkg/selection/session_aware.go new file mode 100644 index 0000000000..40a463799d --- /dev/null +++ b/src/semantic-router/pkg/selection/session_aware.go @@ -0,0 +1,315 @@ +package selection + +import ( + "context" + "fmt" + "math" + "strings" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection/lookuptable" +) + +// SessionAwareConfig configures stay-versus-switch routing for multi-turn sessions. +type SessionAwareConfig struct { + FallbackMethod string `yaml:"fallback_method"` + MinTurnsBeforeSwitch int `yaml:"min_turns_before_switch"` + StayBias float64 `yaml:"stay_bias"` + QualityGapMultiplier float64 `yaml:"quality_gap_multiplier"` + HandoffPenaltyWeight float64 `yaml:"handoff_penalty_weight"` + RemainingTurnWeight float64 `yaml:"remaining_turn_weight"` +} + +// DefaultSessionAwareConfig returns conservative defaults that prefer staying on +// the current session model unless a candidate has a materially better replay-backed score. +func DefaultSessionAwareConfig() *SessionAwareConfig { + return &SessionAwareConfig{ + FallbackMethod: string(MethodStatic), + MinTurnsBeforeSwitch: 1, + StayBias: 0.25, + QualityGapMultiplier: 1.0, + HandoffPenaltyWeight: 1.0, + RemainingTurnWeight: 0.15, + } +} + +// SessionAwareSelector selects models using runtime session facts and lookup-table priors. +type SessionAwareSelector struct { + config *SessionAwareConfig + lookupTable lookuptable.LookupTable +} + +// NewSessionAwareSelector creates a new session-aware selector. +func NewSessionAwareSelector(cfg *SessionAwareConfig) *SessionAwareSelector { + if cfg == nil { + cfg = DefaultSessionAwareConfig() + } + return &SessionAwareSelector{config: cfg} +} + +// Method returns the selection method type. +func (s *SessionAwareSelector) Method() SelectionMethod { + return MethodSessionAware +} + +// SetLookupTable attaches replay-derived lookup-table priors. +func (s *SessionAwareSelector) SetLookupTable(lt lookuptable.LookupTable) { + s.lookupTable = lt +} + +// UpdateFeedback is currently a no-op because the selector consumes replay-derived priors. +func (s *SessionAwareSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error { + _ = ctx + _ = feedback + return nil +} + +// Select chooses the best candidate for a continuing conversation. +func (s *SessionAwareSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) { + _ = ctx + if selCtx == nil { + return nil, fmt.Errorf("selection context is required") + } + if len(selCtx.CandidateModels) == 0 { + return nil, fmt.Errorf("no candidate models provided") + } + + if len(selCtx.CandidateModels) == 1 { + candidate := selCtx.CandidateModels[0] + return &SelectionResult{ + SelectedModel: candidate.Model, + LoRAName: candidate.LoRAName, + Score: 1.0, + Confidence: 1.0, + Method: MethodSessionAware, + Reasoning: "Single candidate available", + AllScores: map[string]float64{ + candidate.Model: 1.0, + }, + }, nil + } + + currentModel := sessionCurrentModel(selCtx) + if selCtx.SessionID == "" || currentModel == "" { + return s.fallbackSelect(ctx, selCtx, "Session context unavailable") + } + + if selCtx.TurnIndex < s.config.MinTurnsBeforeSwitch { + if candidate := findCandidateByCurrentModel(selCtx.CandidateModels, currentModel); candidate != nil { + return &SelectionResult{ + SelectedModel: candidate.Model, + LoRAName: candidate.LoRAName, + Score: 1.0, + Confidence: 1.0, + Method: MethodSessionAware, + Reasoning: fmt.Sprintf( + "Staying on %s because turn_index=%d is below min_turns_before_switch=%d", + currentModel, + selCtx.TurnIndex, + s.config.MinTurnsBeforeSwitch, + ), + AllScores: map[string]float64{ + candidate.Model: 1.0, + }, + }, nil + } + } + + allScores := make(map[string]float64, len(selCtx.CandidateModels)) + bestIdx := 0 + bestScore := math.Inf(-1) + secondBest := math.Inf(-1) + bestReason := "" + + for i := range selCtx.CandidateModels { + candidate := &selCtx.CandidateModels[i] + score, reason := s.scoreCandidate(selCtx, currentModel, candidate) + allScores[candidate.Model] = score + if score > bestScore { + secondBest = bestScore + bestScore = score + bestIdx = i + bestReason = reason + } else if score > secondBest { + secondBest = score + } + } + + bestCandidate := selCtx.CandidateModels[bestIdx] + confidence := 1.0 + if secondBest > math.Inf(-1) { + denominator := math.Max(math.Abs(bestScore), math.Abs(secondBest)) + if denominator < 1.0 { + denominator = 1.0 + } + confidence = clampScore((bestScore-secondBest)/denominator, 0, 1) + } + + logging.Infof("[SessionAwareSelector] Selected %s for session=%s turn=%d current=%s (score=%.4f, confidence=%.2f)", + bestCandidate.Model, selCtx.SessionID, selCtx.TurnIndex, currentModel, bestScore, confidence) + + return &SelectionResult{ + SelectedModel: bestCandidate.Model, + LoRAName: bestCandidate.LoRAName, + Score: bestScore, + Confidence: confidence, + Method: MethodSessionAware, + Reasoning: bestReason, + AllScores: allScores, + }, nil +} + +func (s *SessionAwareSelector) fallbackSelect(ctx context.Context, selCtx *SelectionContext, reason string) (*SelectionResult, error) { + method := strings.TrimSpace(s.config.FallbackMethod) + if method == "" { + method = string(MethodStatic) + } + if SelectionMethod(method) == MethodSessionAware { + method = string(MethodStatic) + } + result, err := Select(ctx, SelectionMethod(method), selCtx) + if err != nil { + return nil, err + } + if result == nil { + return nil, fmt.Errorf("fallback selector %q returned no result", method) + } + result.Reasoning = fmt.Sprintf("%s; fallback=%s: %s", reason, method, result.Reasoning) + return result, nil +} + +func (s *SessionAwareSelector) scoreCandidate(selCtx *SelectionContext, currentModel string, candidate *config.ModelRef) (float64, string) { + qualityGap := s.qualityGap(selCtx, currentModel, candidate.Model) + handoffPenalty := s.handoffPenalty(selCtx, currentModel, candidate.Model) + remainingTurns := s.remainingTurns(selCtx) + + score := s.config.QualityGapMultiplier * qualityGap + reasons := []string{fmt.Sprintf("quality_gap=%.4f", qualityGap)} + + if matchesCurrentModel(*candidate, currentModel) { + stayScore := s.config.StayBias + s.config.RemainingTurnWeight*remainingTurns + clampScore(selCtx.CacheWarmthEstimate, 0, 1) + score += stayScore + reasons = append(reasons, + fmt.Sprintf("stay_bias=%.4f", s.config.StayBias), + fmt.Sprintf("remaining_turns=%.4f", remainingTurns), + fmt.Sprintf("cache_warmth=%.4f", clampScore(selCtx.CacheWarmthEstimate, 0, 1)), + ) + } else { + penalty := s.config.HandoffPenaltyWeight * handoffPenalty + score -= penalty + reasons = append(reasons, fmt.Sprintf("handoff_penalty=-%.4f", penalty)) + } + + return score, strings.Join(reasons, ", ") +} + +func (s *SessionAwareSelector) qualityGap(selCtx *SelectionContext, currentModel, candidateModel string) float64 { + if candidateModel == "" || currentModel == "" || currentModel == candidateModel { + return 0 + } + if selCtx.QualityGapByCandidate != nil { + if value, ok := selCtx.QualityGapByCandidate[candidateModel]; ok { + return value + } + } + if s.lookupTable == nil { + return 0 + } + taskFamily := sessionTaskFamily(selCtx) + if taskFamily == "" { + return 0 + } + value, ok := s.lookupTable.QualityGap(taskFamily, currentModel, candidateModel) + if !ok { + return 0 + } + return value +} + +func (s *SessionAwareSelector) handoffPenalty(selCtx *SelectionContext, currentModel, candidateModel string) float64 { + if candidateModel == "" || currentModel == "" || currentModel == candidateModel { + return 0 + } + if selCtx.HandoffPenaltyByCandidate != nil { + if value, ok := selCtx.HandoffPenaltyByCandidate[candidateModel]; ok { + return value + } + } + if s.lookupTable == nil { + return 0 + } + value, ok := s.lookupTable.HandoffPenalty(currentModel, candidateModel) + if !ok { + return 0 + } + return value +} + +func (s *SessionAwareSelector) remainingTurns(selCtx *SelectionContext) float64 { + if selCtx.RemainingTurnsEstimate > 0 { + return selCtx.RemainingTurnsEstimate + } + if s.lookupTable == nil { + return 0 + } + taskFamily := sessionTaskFamily(selCtx) + if taskFamily == "" { + return 0 + } + prior, ok := s.lookupTable.RemainingTurnPrior(taskFamily) + if !ok { + return 0 + } + remaining := prior - float64(selCtx.TurnIndex) + if remaining < 0 { + return 0 + } + return remaining +} + +func sessionTaskFamily(selCtx *SelectionContext) string { + if selCtx == nil { + return "" + } + if strings.TrimSpace(selCtx.CategoryName) != "" { + return strings.TrimSpace(selCtx.CategoryName) + } + return strings.TrimSpace(selCtx.DecisionName) +} + +func sessionCurrentModel(selCtx *SelectionContext) string { + if selCtx == nil { + return "" + } + for _, candidate := range []string{selCtx.CurrentModel, selCtx.PreviousModel, selCtx.OriginalModel} { + if trimmed := strings.TrimSpace(candidate); trimmed != "" { + return trimmed + } + } + return "" +} + +func findCandidateByCurrentModel(candidates []config.ModelRef, currentModel string) *config.ModelRef { + for i := range candidates { + if matchesCurrentModel(candidates[i], currentModel) { + return &candidates[i] + } + } + return nil +} + +func matchesCurrentModel(candidate config.ModelRef, currentModel string) bool { + return strings.TrimSpace(candidate.Model) == strings.TrimSpace(currentModel) || + (strings.TrimSpace(candidate.LoRAName) != "" && strings.TrimSpace(candidate.LoRAName) == strings.TrimSpace(currentModel)) +} + +func clampScore(value, minValue, maxValue float64) float64 { + if value < minValue { + return minValue + } + if value > maxValue { + return maxValue + } + return value +} diff --git a/src/semantic-router/pkg/selection/tier_declarations.go b/src/semantic-router/pkg/selection/tier_declarations.go index b7becd21b5..66ae5e99b8 100644 --- a/src/semantic-router/pkg/selection/tier_declarations.go +++ b/src/semantic-router/pkg/selection/tier_declarations.go @@ -58,6 +58,16 @@ func (h *HybridSelector) ExternalDependencies() []Dependency { return []Dependency{} } +// Tier returns the production readiness tier +func (s *SessionAwareSelector) Tier() AlgorithmTier { + return TierSupported +} + +// ExternalDependencies returns external dependencies (none for session-aware) +func (s *SessionAwareSelector) ExternalDependencies() []Dependency { + return []Dependency{} +} + // --- Experimental-tier algorithms --- // Tier returns the production readiness tier diff --git a/src/vllm-sr/cli/algorithms.py b/src/vllm-sr/cli/algorithms.py index 0d5ecdc3e8..ea94c968cc 100644 --- a/src/vllm-sr/cli/algorithms.py +++ b/src/vllm-sr/cli/algorithms.py @@ -349,6 +349,7 @@ class AlgorithmConfig(BaseModel): router_dc: RouterDCSelectionConfig | None = None automix: AutoMixSelectionConfig | None = None hybrid: HybridSelectionConfig | None = None + session_aware: SessionAwareSelectionConfig | None = None # RL-driven selection algorithms (from PR #1196, issue #994) thompson: ThompsonSamplingConfig | None = None diff --git a/src/vllm-sr/cli/config_contract.py b/src/vllm-sr/cli/config_contract.py index ac7f6fcfa5..15867e823f 100644 --- a/src/vllm-sr/cli/config_contract.py +++ b/src/vllm-sr/cli/config_contract.py @@ -85,6 +85,7 @@ class SignalFamilySpec: ("easy", "medium", "hard"), ), SignalFamilySpec("modality", "modality", "modality", "modality_rules"), + SignalFamilySpec("session", "session", "session"), SignalFamilySpec("role_bindings", "role_bindings", "authz", "role_bindings"), SignalFamilySpec("jailbreak", "jailbreak", "jailbreak", "jailbreak"), SignalFamilySpec("pii", "pii", "pii", "pii"), diff --git a/src/vllm-sr/cli/models.py b/src/vllm-sr/cli/models.py index 82f777729d..2bdd4b8adc 100644 --- a/src/vllm-sr/cli/models.py +++ b/src/vllm-sr/cli/models.py @@ -346,6 +346,7 @@ class Signals(BaseModel): structure: Optional[List[StructureRule]] = [] complexity: Optional[List[ComplexityRule]] = [] modality: Optional[List[ModalityRule]] = [] + session: Optional[List[SessionRule]] = [] role_bindings: Optional[List[RoleBindingRule]] = [] jailbreak: Optional[List[JailbreakRule]] = [] pii: Optional[List[PIIRule]] = [] diff --git a/src/vllm-sr/cli/validator.py b/src/vllm-sr/cli/validator.py index 2c4dcc3ee1..a15c97488c 100644 --- a/src/vllm-sr/cli/validator.py +++ b/src/vllm-sr/cli/validator.py @@ -77,6 +77,12 @@ def _is_latency_aware_algorithm(decision) -> bool: return (decision.algorithm.type or "").strip().lower() == "latency_aware" +def _is_session_aware_algorithm(decision) -> bool: + if not decision.algorithm: + return False + return (decision.algorithm.type or "").strip().lower() == "session_aware" + + def validate_latency_compatibility(config: UserConfig) -> List[ValidationError]: errors = [] has_legacy_conditions = any( @@ -475,6 +481,7 @@ def validate_algorithm_configurations(config: UserConfig) -> List[ValidationErro "router_dc", "automix", "hybrid", + "session_aware", "latency_aware", "thompson", "gmtrouter", @@ -560,6 +567,7 @@ def validate_user_config(config: UserConfig) -> List[ValidationError]: errors.extend(validate_latency_compatibility(config)) errors.extend(validate_algorithm_one_of(config)) errors.extend(validate_latency_aware_algorithm_config(config)) + errors.extend(validate_session_aware_algorithm_config(config)) # Validate domain references errors.extend(validate_domain_references(config)) diff --git a/tools/agent/repo-manifest.yaml b/tools/agent/repo-manifest.yaml index 160e172163..148d515b7b 100644 --- a/tools/agent/repo-manifest.yaml +++ b/tools/agent/repo-manifest.yaml @@ -150,6 +150,7 @@ docs: - docs/agent/plans/pl-0027-router-runtime-composition-root-convergence-loop.md - docs/agent/plans/pl-0028-knowledge-base-seed-and-steady-state-convergence-loop.md - docs/agent/plans/pl-0029-control-plane-contract-boundary-ratchet.md + - docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md - docs/agent/state-taxonomy-and-inventory.md - docs/agent/tech-debt-register.md - docs/agent/tech-debt/README.md @@ -378,6 +379,9 @@ doc_governance: - path: docs/agent/plans/pl-0029-control-plane-contract-boundary-ratchet.md steward: agent-contract freshness: update at the start and end of each control-plane contract boundary ratchet loop and after each completed loop task + - path: docs/agent/plans/pl-0030-session-aware-routing-convergence-loop.md + steward: router-core + freshness: update at the start and end of each session-aware routing convergence loop and after each completed loop task - path: docs/agent/state-taxonomy-and-inventory.md steward: agent-contract freshness: update when the runtime or dashboard state inventory, durability taxonomy, or persistence guidance changes materially diff --git a/website/docs/tutorials/algorithm/overview.md b/website/docs/tutorials/algorithm/overview.md index bf2d3ab12f..4a1f154117 100644 --- a/website/docs/tutorials/algorithm/overview.md +++ b/website/docs/tutorials/algorithm/overview.md @@ -120,6 +120,7 @@ flowchart TD - [KNN](./selection/knn) - [Latency Aware](./selection/latency-aware) - [MLP](./selection/mlp) +- [Session Aware](./selection/session-aware) - [RL Driven](./selection/rl-driven) - [Router DC](./selection/router-dc) - [Static](./selection/static) diff --git a/website/docs/tutorials/algorithm/selection/session-aware.md b/website/docs/tutorials/algorithm/selection/session-aware.md new file mode 100644 index 0000000000..d1684e03aa --- /dev/null +++ b/website/docs/tutorials/algorithm/selection/session-aware.md @@ -0,0 +1,85 @@ +# Session Aware + +## Overview + +`session_aware` is a selection algorithm for multi-turn conversations. It prefers staying on the current model unless runtime session facts and replay-backed lookup-table priors indicate that switching to another candidate is worth the handoff cost. + +It aligns to `config/algorithm/selection/session-aware.yaml`. + +## Key Advantages + +- Makes mid-session stay-versus-switch behavior explicit in config. +- Uses replay-backed priors instead of hard-coding one-off heuristics per route. +- Preserves a conservative fallback path when session context is missing. +- Keeps multi-turn routing auditable at the same layer as the rest of `decision.algorithm`. + +## What Problem Does It Solve? + +Plain single-turn selectors only score the current request. In a multi-turn session, switching models also has a continuity cost: cache warmth, handoff penalty, and the loss of the current model's conversational context. + +`session_aware` solves that by combining runtime-derived session facts with replay-backed priors so the router can decide whether to stay on the current model or switch to a better candidate. + +## When to Use + +Use `session_aware` when: + +- one decision serves multi-turn traffic +- the current model choice should depend on prior turns +- switching models has a meaningful continuity or cache cost +- you want a conservative fallback such as `static` or `hybrid` + +## Configuration + +```yaml +algorithm: + type: session_aware + session_aware: + fallback_method: hybrid + min_turns_before_switch: 2 + stay_bias: 0.3 + quality_gap_multiplier: 1.15 + handoff_penalty_weight: 0.9 + remaining_turn_weight: 0.45 +``` + +### Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `fallback_method` | string | `static` | Selector used when session context is unavailable or insufficient | +| `min_turns_before_switch` | int | `1` | Minimum turn depth before the selector considers switching | +| `stay_bias` | float | `0.25` | Baseline preference for keeping the current session model | +| `quality_gap_multiplier` | float | `1.0` | Weight applied to replay-backed quality-gap estimates | +| `handoff_penalty_weight` | float | `1.0` | Weight applied to replay-derived switch penalties | +| `remaining_turn_weight` | float | `0.15` | Extra value for continuity when more turns are expected | + +## Select Flow + +```mermaid +flowchart TD + A[Decision matched] --> B[algorithm.type = session_aware] + B --> C{Session context present?} + C -- No --> D[Fallback to fallback_method] + C -- Yes --> E[Read turn index and previous model] + E --> F[Read replay-backed lookup priors] + F --> G[Score stay vs switch] + G --> H{Switch beats stay?} + H -- No --> I[Keep current model] + H -- Yes --> J[Select higher-value candidate] +``` + +## Runtime Inputs + +`session_aware` depends on runtime-derived facts that the router already injects into the selection context: + +- current `session_id` +- `turn_index` +- `previous_model` +- replay-derived quality-gap / handoff-penalty priors +- expected remaining-turn value inferred by the selector + +## Known Limitations + +- It relies on replay-backed priors, so cold-start data is less informative. +- It is intentionally narrower than the earlier `SESSION_STATE` design; there is no general-purpose DSL state machine here. +- It optimizes stay-versus-switch behavior inside one matched decision, not full fleet-wide RL policy tuning. diff --git a/website/docs/tutorials/signal/heuristic/session.md b/website/docs/tutorials/signal/heuristic/session.md new file mode 100644 index 0000000000..22b88c7425 --- /dev/null +++ b/website/docs/tutorials/signal/heuristic/session.md @@ -0,0 +1,80 @@ +# Session Signal + +## Overview + +`session` exposes runtime-derived multi-turn facts as named routing signals under `routing.signals.session`. + +It maps to `config/signal/session/` and is declared as part of the normal signal catalog, so decisions and projections can reference session facts without reviving the removed `SESSION_STATE` public surface. + +## Key Advantages + +- Lets decisions reference multi-turn state through named reusable signals. +- Keeps session-aware routing inside the same config graph as other signals. +- Supports numeric predicates over runtime facts such as turn depth or cache warmth. +- Allows model-specific continuity rules using `previous_model` and `candidate_model`. + +## What Problem Does It Solve? + +A multi-turn route often depends on facts that are not visible in the raw prompt: whether the request belongs to an existing session, which model served the last turn, and whether switching models is likely to be expensive. + +`session` solves that by exposing those runtime-derived facts as normal routing signals, so routes can compose them with domain, complexity, safety, or projection logic. + +## When to Use + +Use `session` when: + +- a route behaves differently on the first turn vs. continuation turns +- you want continuity rules tied to the previous or candidate model +- replay-backed routing logic needs explicit named signal references +- decisions should combine session facts with domain or projection conditions + +## Configuration + +Source fragment family: `config/signal/session/` + +```yaml +routing: + signals: + session: + - name: session_present + description: Requests that belong to an existing multi-turn conversation. + fact: session_present + predicate: + gte: 1 + - name: warm_cache_continuation + description: Prefer staying on the warmed model when the same conversation continues. + fact: cache_warmth + previous_model: qwen3-8b + predicate: + gte: 0.6 + - name: expensive_handoff + description: Detect costly mid-session upgrades into the premium coding model. + fact: handoff_penalty + intent_or_domain: computer science + previous_model: qwen3-8b + candidate_model: qwen3-32b + predicate: + gte: 0.15 +``` + +### Parameters + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `name` | yes | Signal name referenced from decisions and projections | +| `fact` | yes | Runtime-derived fact name injected by the router | +| `predicate` | no | Numeric threshold predicate over the fact value | +| `intent_or_domain` | no | Optional domain/task-family guard | +| `previous_model` | no | Only match when the previous turn used this model | +| `candidate_model` | no | Only match when evaluating this candidate model | +| `description` | no | Human-readable explanation | + +## Runtime Behavior + +The router computes session signal values from request/session context and stores them alongside the normal signal confidence/value map. Decisions can then reference them with `type: session` just like other signal families. + +## Known Limitations + +- `session` is a routing surface, not a general-purpose persisted DSL state machine. +- Facts are runtime-derived and router-owned; unsupported fact names will fail config validation. +- Session signals help describe continuity conditions, but the actual stay-versus-switch scoring lives in `algorithm.session_aware`. diff --git a/website/docs/tutorials/signal/overview.md b/website/docs/tutorials/signal/overview.md index 2b96ac41e0..a1ee4e7469 100644 --- a/website/docs/tutorials/signal/overview.md +++ b/website/docs/tutorials/signal/overview.md @@ -89,6 +89,7 @@ These signals route from explicit rules, request form, or lightweight detectors | `context` | `config/signal/context/` | route by effective token-window needs | [Context](./heuristic/context) | | `keyword` | `config/signal/keyword/` | route from lexical or BM25-style matches | [Keyword](./heuristic/keyword) | | `language` | `config/signal/language/` | route by detected request language | [Language](./heuristic/language) | +| `session` | `config/signal/session/` | route from runtime-derived multi-turn session facts | [Session](./heuristic/session) | | `structure` | `config/signal/structure/` | route from request shape such as question counts or ordered workflow markers | [Structure](./heuristic/structure) | ### Learned Signals From b8b91c8bd14feb45417398baa956adefc0521584 Mon Sep 17 00:00:00 2001 From: xunzhuo Date: Wed, 15 Apr 2026 23:56:44 +0800 Subject: [PATCH 2/2] fix: lint Signed-off-by: xunzhuo --- .../src/pages/ConfigPageSignalsSection.tsx | 1 - .../frontend/src/pages/topology/constants.ts | 3 +++ dashboard/frontend/src/pages/topology/types.ts | 12 ++++++++++++ src/semantic-router/pkg/config/config.go | 9 ++++----- src/semantic-router/pkg/extproc/router.go | 4 ++-- src/vllm-sr/cli/algorithms.py | 15 +++++++++++++++ src/vllm-sr/cli/models.py | 12 ++++++++++++ src/vllm-sr/cli/validator.py | 3 +++ 8 files changed, 51 insertions(+), 8 deletions(-) diff --git a/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx b/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx index 2dc607dbbe..02fbe64215 100644 --- a/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx +++ b/dashboard/frontend/src/pages/ConfigPageSignalsSection.tsx @@ -18,7 +18,6 @@ import type { KeywordSignal, LanguageSignal, ModalitySignal, - SessionSignal, PIISignal, PreferenceSignal, ReaskSignal, diff --git a/dashboard/frontend/src/pages/topology/constants.ts b/dashboard/frontend/src/pages/topology/constants.ts index 4245d0d6e6..24bc2c7cd6 100644 --- a/dashboard/frontend/src/pages/topology/constants.ts +++ b/dashboard/frontend/src/pages/topology/constants.ts @@ -38,6 +38,7 @@ export const SIGNAL_COLORS: Record = { automix: 'AM', hybrid: 'HY', remom: 'RM', + session_aware: 'SES', latency_aware: 'LAT', } @@ -281,6 +283,7 @@ export const SIGNAL_TYPES: SignalType[] = [ 'structure', 'complexity', 'modality', + 'session', 'authz', 'jailbreak', 'pii', diff --git a/dashboard/frontend/src/pages/topology/types.ts b/dashboard/frontend/src/pages/topology/types.ts index a796c9ee8b..1d792596e2 100644 --- a/dashboard/frontend/src/pages/topology/types.ts +++ b/dashboard/frontend/src/pages/topology/types.ts @@ -16,6 +16,7 @@ export type SignalType = | 'structure' | 'complexity' | 'modality' + | 'session' | 'authz' | 'jailbreak' | 'pii' @@ -187,6 +188,7 @@ export type AlgorithmType = | 'automix' | 'hybrid' | 'remom' + | 'session_aware' | 'latency_aware' export interface AlgorithmConfig { @@ -194,6 +196,7 @@ export interface AlgorithmConfig { confidence?: ConfidenceAlgorithmConfig concurrent?: ConcurrentAlgorithmConfig latency_aware?: LatencyAwareAlgorithmConfig + session_aware?: SessionAwareAlgorithmConfig autoMix?: AutoMixConfig } @@ -620,6 +623,15 @@ export interface ConfigData { name: string description?: string }> + session?: Array<{ + name: string + description?: string + fact: string + predicate?: NumericPredicateConfig + intent_or_domain?: string + previous_model?: string + candidate_model?: string + }> role_bindings?: Array<{ name: string role: string diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 05795a5038..783ec15af0 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -184,11 +184,10 @@ type IntelligentRouting struct { Signals `yaml:",inline"` Projections Projections `yaml:"projections,omitempty"` Decisions []Decision `yaml:"decisions,omitempty"` - Strategy string `yaml:"strategy,omitempty"` - ModelSelection ModelSelectionConfig `yaml:"model_selection,omitempty"` - ReasoningConfig `yaml:",inline"` - } - + Strategy string `yaml:"strategy,omitempty"` + ModelSelection ModelSelectionConfig `yaml:"model_selection,omitempty"` + ReasoningConfig `yaml:",inline"` +} // BackendModels captures configured backend endpoints and model metadata. type BackendModels struct { diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go index 92f59a5092..40efe0a01a 100644 --- a/src/semantic-router/pkg/extproc/router.go +++ b/src/semantic-router/pkg/extproc/router.go @@ -35,8 +35,8 @@ type OpenAIRouter struct { ReplayStoreShared bool // ModelSelector is the registry of advanced model selection algorithms // initialized from config.IntelligentRouting.ModelSelection. - ModelSelector *selection.Registry - LookupTable lookuptable.LookupTable + ModelSelector *selection.Registry + LookupTable lookuptable.LookupTable ReplayRecorders map[string]*routerreplay.Recorder MemoryStore memory.Store MemoryExtractor *memory.MemoryExtractor diff --git a/src/vllm-sr/cli/algorithms.py b/src/vllm-sr/cli/algorithms.py index ea94c968cc..759eea4d27 100644 --- a/src/vllm-sr/cli/algorithms.py +++ b/src/vllm-sr/cli/algorithms.py @@ -233,6 +233,21 @@ class HybridSelectionConfig(BaseModel): normalize_scores: bool | None = True +class SessionAwareSelectionConfig(BaseModel): + """Configuration for session-aware model selection. + + Balances stay-versus-switch decisions using runtime session facts and + replay-backed priors. + """ + + fallback_method: str | None = "static" + min_turns_before_switch: int | None = Field(default=1, ge=0) + stay_bias: float | None = Field(default=0.25, ge=0) + quality_gap_multiplier: float | None = Field(default=1.0, ge=0) + handoff_penalty_weight: float | None = Field(default=1.0, ge=0) + remaining_turn_weight: float | None = Field(default=0.15, ge=0) + + # ============================================================================= # RL-Driven Model Selection Algorithm Configs (from PR #1196 / Issue #994) # Reference papers: diff --git a/src/vllm-sr/cli/models.py b/src/vllm-sr/cli/models.py index 2bdd4b8adc..c54d03f80f 100644 --- a/src/vllm-sr/cli/models.py +++ b/src/vllm-sr/cli/models.py @@ -285,6 +285,18 @@ class ModalityRule(BaseModel): description: Optional[str] = None +class SessionRule(BaseModel): + """Session-derived routing signal configuration.""" + + name: str + description: Optional[str] = None + fact: str + predicate: Optional[NumericPredicate] = None + intent_or_domain: Optional[str] = None + previous_model: Optional[str] = None + candidate_model: Optional[str] = None + + class Subject(BaseModel): """RBAC subject (user or group) for role binding.""" diff --git a/src/vllm-sr/cli/validator.py b/src/vllm-sr/cli/validator.py index a15c97488c..f169379238 100644 --- a/src/vllm-sr/cli/validator.py +++ b/src/vllm-sr/cli/validator.py @@ -157,6 +157,7 @@ def validate_algorithm_one_of(config: UserConfig) -> List[ValidationError]: "concurrent": "concurrent", "remom": "remom", "latency_aware": "latency_aware", + "session_aware": "session_aware", } for decision in config.decisions: @@ -173,6 +174,8 @@ def validate_algorithm_one_of(config: UserConfig) -> List[ValidationError]: configured_blocks.append("remom") if algorithm.latency_aware is not None: configured_blocks.append("latency_aware") + if algorithm.session_aware is not None: + configured_blocks.append("session_aware") display_type = (algorithm.type or "").strip() or "" normalized_type = (algorithm.type or "").strip().lower()