Skip to content

Commit 77ea7f2

Browse files
committed
fix: honor explicit model provider routing
1 parent 8819ad7 commit 77ea7f2

7 files changed

Lines changed: 65 additions & 23 deletions

File tree

src/agent/agent.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ const OVERFLOW_KEEP_ROUNDS = 3;
3636
*/
3737
export class Agent {
3838
private readonly model: string;
39+
private readonly modelProvider?: string;
3940
private readonly maxIterations: number;
4041
private readonly tools: StructuredToolInterface[];
4142
private readonly toolMap: Map<string, StructuredToolInterface>;
@@ -53,6 +54,7 @@ export class Agent {
5354
concurrencyMap: Map<string, boolean>,
5455
) {
5556
this.model = config.model ?? DEFAULT_MODEL;
57+
this.modelProvider = config.modelProvider;
5658
this.maxIterations = config.maxIterations ?? DEFAULT_MAX_ITERATIONS;
5759
this.tools = tools;
5860
this.toolMap = new Map(tools.map(t => [t.name, t]));
@@ -160,7 +162,7 @@ export class Agent {
160162
}
161163

162164
const totalTime = Date.now() - ctx.startTime;
163-
const provider = resolveProvider(this.model).displayName;
165+
const provider = resolveProvider(this.model, this.modelProvider).displayName;
164166
yield {
165167
type: 'done',
166168
answer: `Error: ${formatUserFacingError(errorMessage, provider)}`,
@@ -300,6 +302,7 @@ export class Agent {
300302

301303
for await (const chunk of streamLlmWithMessages(messages, {
302304
model: this.model,
305+
modelProvider: this.modelProvider,
303306
tools: this.tools,
304307
signal: this.signal,
305308
})) {
@@ -345,6 +348,7 @@ export class Agent {
345348
): Promise<{ response: AIMessage; usage?: TokenUsage }> {
346349
const result = await callLlmWithMessages(messages, {
347350
model: this.model,
351+
modelProvider: this.modelProvider,
348352
tools: this.tools,
349353
signal: this.signal,
350354
});
@@ -541,7 +545,7 @@ export class Agent {
541545
: estimateTokens(messageState.messages.map(m =>
542546
typeof m.content === 'string' ? m.content : JSON.stringify(m.content),
543547
).join('\n'));
544-
const threshold = getAutoCompactThreshold(this.model);
548+
const threshold = getAutoCompactThreshold(this.model, this.modelProvider);
545549

546550
if (estimatedContextTokens <= threshold) {
547551
return;
@@ -560,6 +564,7 @@ export class Agent {
560564
yield { type: 'memory_flush', phase: 'start' };
561565
const flushResult = await runMemoryFlush({
562566
model: this.model,
567+
modelProvider: this.modelProvider,
563568
systemPrompt: this.systemPrompt,
564569
query,
565570
toolResults: fullToolResults,
@@ -583,6 +588,7 @@ export class Agent {
583588
try {
584589
const result = await compactContext({
585590
model: this.model,
591+
modelProvider: this.modelProvider,
586592
systemPrompt: this.systemPrompt,
587593
query,
588594
toolResults: fullToolResults,
@@ -611,7 +617,7 @@ export class Agent {
611617
success: true,
612618
preCompactTokens: estimatedContextTokens,
613619
postCompactTokens,
614-
compactionModel: resolveProvider(this.model).fastModel ?? this.model,
620+
compactionModel: resolveProvider(this.model, this.modelProvider).fastModel ?? this.model,
615621
};
616622

617623
return;

src/agent/compact.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ Continue working toward answering the query without asking the user any further
168168
export interface CompactContextParams {
169169
/** Main model name (used to resolve provider and fast model). */
170170
model: string;
171+
/** Explicit provider override from agent settings. */
172+
modelProvider?: string;
171173
/** System prompt for the compaction call. */
172174
systemPrompt: string;
173175
/** Original user query. */
@@ -192,10 +194,10 @@ export interface CompactResult {
192194
* Throws on failure — caller is responsible for fallback to clearing.
193195
*/
194196
export async function compactContext(params: CompactContextParams): Promise<CompactResult> {
195-
const { model, systemPrompt, query, toolResults, signal } = params;
197+
const { model, modelProvider, systemPrompt, query, toolResults, signal } = params;
196198

197199
// Resolve fast model for the current provider
198-
const provider = resolveProvider(model);
200+
const provider = resolveProvider(model, modelProvider);
199201
const fastModel = provider.fastModel ?? model;
200202

201203
// Build the compaction prompt
@@ -204,6 +206,7 @@ export async function compactContext(params: CompactContextParams): Promise<Comp
204206
// Call LLM with no tools bound — callLlm returns string in this case
205207
const result = await callLlm(prompt, {
206208
model: fastModel,
209+
modelProvider: provider.id,
207210
systemPrompt,
208211
signal,
209212
});

src/memory/flush.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export function shouldRunMemoryFlush(params: {
3838

3939
export async function runMemoryFlush(params: {
4040
model: string;
41+
modelProvider?: string;
4142
systemPrompt: string;
4243
query: string;
4344
toolResults: string;
@@ -55,6 +56,7 @@ ${MEMORY_FLUSH_PROMPT}
5556

5657
const result = await callLlm(prompt, {
5758
model: params.model,
59+
modelProvider: params.modelProvider,
5860
systemPrompt: params.systemPrompt,
5961
signal: params.signal,
6062
});

src/model/llm.ts

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,18 @@ const DEFAULT_FACTORY: ModelFactory = (name, opts) =>
143143

144144
export function getChatModel(
145145
modelName: string = DEFAULT_MODEL,
146-
streaming: boolean = false
146+
streaming: boolean = false,
147+
providerOverride?: string,
147148
): BaseChatModel {
148149
const opts: ModelOpts = { streaming };
149-
const provider = resolveProvider(modelName);
150+
const provider = resolveProvider(modelName, providerOverride);
150151
const factory = MODEL_FACTORIES[provider.id] ?? DEFAULT_FACTORY;
151152
return factory(modelName, opts);
152153
}
153154

154155
interface CallLlmOptions {
155156
model?: string;
157+
modelProvider?: string;
156158
systemPrompt?: string;
157159
outputSchema?: z.ZodType<unknown>;
158160
tools?: StructuredToolInterface[];
@@ -213,10 +215,10 @@ function buildAnthropicMessages(systemPrompt: string, userPrompt: string) {
213215
}
214216

215217
export async function callLlm(prompt: string, options: CallLlmOptions = {}): Promise<LlmResult> {
216-
const { model = DEFAULT_MODEL, systemPrompt, outputSchema, tools, signal } = options;
218+
const { model = DEFAULT_MODEL, modelProvider, systemPrompt, outputSchema, tools, signal } = options;
217219
const finalSystemPrompt = systemPrompt || DEFAULT_SYSTEM_PROMPT;
218220

219-
const llm = getChatModel(model, false);
221+
const llm = getChatModel(model, false, modelProvider);
220222

221223
// eslint-disable-next-line @typescript-eslint/no-explicit-any
222224
let runnable: Runnable<any, any> = llm;
@@ -228,7 +230,7 @@ export async function callLlm(prompt: string, options: CallLlmOptions = {}): Pro
228230
}
229231

230232
const invokeOpts = signal ? { signal } : undefined;
231-
const provider = resolveProvider(model);
233+
const provider = resolveProvider(model, modelProvider);
232234
let result;
233235

234236
if (provider.id === 'anthropic') {
@@ -287,6 +289,7 @@ function annotateSystemMessageForCaching(messages: BaseMessage[]): BaseMessage[]
287289

288290
interface CallLlmWithMessagesOptions {
289291
model?: string;
292+
modelProvider?: string;
290293
tools?: StructuredToolInterface[];
291294
signal?: AbortSignal;
292295
}
@@ -306,9 +309,9 @@ export async function callLlmWithMessages(
306309
messages: BaseMessage[],
307310
options: CallLlmWithMessagesOptions = {},
308311
): Promise<LlmResult> {
309-
const { model = DEFAULT_MODEL, tools, signal } = options;
312+
const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options;
310313

311-
const llm = getChatModel(model, false);
314+
const llm = getChatModel(model, false, modelProvider);
312315

313316
// eslint-disable-next-line @typescript-eslint/no-explicit-any
314317
let runnable: Runnable<any, any> = llm;
@@ -318,7 +321,7 @@ export async function callLlmWithMessages(
318321
}
319322

320323
const invokeOpts = signal ? { signal } : undefined;
321-
const provider = resolveProvider(model);
324+
const provider = resolveProvider(model, modelProvider);
322325

323326
// For Anthropic: annotate SystemMessage with cache_control for prompt caching
324327
const finalMessages = provider.id === 'anthropic'
@@ -349,9 +352,9 @@ export async function* streamLlmWithMessages(
349352
messages: BaseMessage[],
350353
options: CallLlmWithMessagesOptions = {},
351354
): AsyncGenerator<AIMessageChunk, void> {
352-
const { model = DEFAULT_MODEL, tools, signal } = options;
355+
const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options;
353356

354-
const llm = getChatModel(model, true);
357+
const llm = getChatModel(model, true, modelProvider);
355358

356359
// eslint-disable-next-line @typescript-eslint/no-explicit-any
357360
let runnable: Runnable<any, any> = llm;
@@ -361,7 +364,7 @@ export async function* streamLlmWithMessages(
361364
}
362365

363366
const invokeOpts = signal ? { signal } : undefined;
364-
const provider = resolveProvider(model);
367+
const provider = resolveProvider(model, modelProvider);
365368

366369
const finalMessages = provider.id === 'anthropic'
367370
? annotateSystemMessageForCaching(messages)

src/providers.test.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { describe, expect, test } from 'bun:test';
2+
import { resolveProvider } from './providers.js';
3+
4+
describe('resolveProvider', () => {
5+
test('uses explicit provider before model prefix routing', () => {
6+
const provider = resolveProvider('deepseek-v4-flash', 'openai');
7+
8+
expect(provider.id).toBe('openai');
9+
});
10+
11+
test('falls back to model prefix routing without override', () => {
12+
const provider = resolveProvider('deepseek-v4-flash');
13+
14+
expect(provider.id).toBe('deepseek');
15+
});
16+
17+
test('ignores unknown provider override', () => {
18+
const provider = resolveProvider('claude-sonnet-4-5', 'unknown-provider');
19+
20+
expect(provider.id).toBe('anthropic');
21+
});
22+
});

src/providers.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,16 @@ export const PROVIDERS: ProviderDef[] = [
8686
const defaultProvider = PROVIDERS.find((p) => p.id === 'openai')!;
8787

8888
/**
89-
* Resolve the provider for a given model name based on its prefix.
90-
* Falls back to OpenAI when no prefix matches.
89+
* Resolve the provider for a given model name.
90+
* Explicit provider settings take precedence over model-name prefix routing.
91+
* An unknown override is ignored and prefix routing still applies; falls back to OpenAI only when no prefix matches.
9192
*/
92-
export function resolveProvider(modelName: string): ProviderDef {
93+
export function resolveProvider(modelName: string, providerOverride?: string): ProviderDef {
94+
if (providerOverride) {
95+
const provider = getProviderById(providerOverride);
96+
if (provider) return provider;
97+
}
98+
9399
return (
94100
PROVIDERS.find((p) => p.modelPrefix && modelName.startsWith(p.modelPrefix)) ??
95101
defaultProvider

src/utils/tokens.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ const DEFAULT_CONTEXT_WINDOW = 128_000;
3636
* Get the effective context window size for a model, accounting for
3737
* reserved output tokens.
3838
*/
39-
export function getEffectiveContextWindow(model: string): number {
40-
const provider = resolveProvider(model);
39+
export function getEffectiveContextWindow(model: string, modelProvider?: string): number {
40+
const provider = resolveProvider(model, modelProvider);
4141
const contextWindow = provider.contextWindow ?? DEFAULT_CONTEXT_WINDOW;
4242
return contextWindow - MAX_OUTPUT_TOKENS_FOR_SUMMARY;
4343
}
@@ -47,8 +47,8 @@ export function getEffectiveContextWindow(model: string): number {
4747
* This is the token count at which compaction should trigger.
4848
* Formula: effectiveWindow - 13K buffer.
4949
*/
50-
export function getAutoCompactThreshold(model: string): number {
51-
return getEffectiveContextWindow(model) - AUTOCOMPACT_BUFFER_TOKENS;
50+
export function getAutoCompactThreshold(model: string, modelProvider?: string): number {
51+
return getEffectiveContextWindow(model, modelProvider) - AUTOCOMPACT_BUFFER_TOKENS;
5252
}
5353

5454
// ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)