Skip to content

Commit f177397

Browse files
authored
fix: only fallback to models that are immediately the same (#2001)
1 parent f6139bc commit f177397

File tree

1 file changed

+48
-39
lines changed

1 file changed

+48
-39
lines changed

src/backend/src/modules/puterai/AIChatService.js

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class AIChatService extends BaseService {
4848
};
4949

5050
/** @type {import('../../services/MeteringService/MeteringService').MeteringService} */
51-
get meteringService(){
51+
get meteringService () {
5252
return this.services.get('meteringService').meteringService;
5353
}
5454
/**
@@ -58,14 +58,14 @@ class AIChatService extends BaseService {
5858
* Called during service instantiation.
5959
* @private
6060
*/
61-
_construct() {
61+
_construct () {
6262
this.providers = [];
6363
this.simple_model_list = [];
6464
this.detail_model_list = [];
6565
this.detail_model_map = {};
6666
}
6767

68-
get_model_details(model_name, context) {
68+
get_model_details (model_name, context) {
6969
let model_details = this.detail_model_map[model_name];
7070
if ( Array.isArray(model_details) && context ) {
7171
for ( const model of model_details ) {
@@ -88,7 +88,7 @@ class AIChatService extends BaseService {
8888
* as well as an empty object for the detailed model map.
8989
* @private
9090
*/
91-
_init() {
91+
_init () {
9292
this.kvkey = this.modules.uuidv4();
9393

9494
this.db = this.services.get('database').get(DB_WRITE, 'ai-usage');
@@ -117,7 +117,7 @@ class AIChatService extends BaseService {
117117
*
118118
* @returns {Promise<void>}
119119
*/
120-
async ['__on_boot.consolidation']() {
120+
async ['__on_boot.consolidation'] () {
121121
{
122122
const svc_driver = this.services.get('driver');
123123
for ( const provider of this.providers ) {
@@ -215,13 +215,13 @@ class AIChatService extends BaseService {
215215
}
216216
}
217217

218-
register_provider(spec) {
218+
register_provider (spec) {
219219
this.providers.push(spec);
220220
}
221221

222222
static IMPLEMENTS = {
223223
['driver-capabilities']: {
224-
supports_test_mode(iface, method_name) {
224+
supports_test_mode (iface, method_name) {
225225
return iface === 'puter-chat-completion' &&
226226
method_name === 'complete';
227227
},
@@ -253,7 +253,7 @@ class AIChatService extends BaseService {
253253
*
254254
* @returns {Promise<Array<Object>>} Array of model objects with details like id, provider, cost, etc.
255255
*/
256-
async models() {
256+
async models () {
257257
const delegate = this.get_delegate();
258258
if ( ! delegate ) return await this.models_();
259259
return await delegate.models();
@@ -264,7 +264,7 @@ class AIChatService extends BaseService {
264264
* detail.
265265
* @returns {Promise<Array<string>>} Array of model objects with basic details
266266
*/
267-
async list() {
267+
async list () {
268268
const delegate = this.get_delegate();
269269
if ( ! delegate ) return await this.list_();
270270
return await delegate.list();
@@ -297,7 +297,7 @@ class AIChatService extends BaseService {
297297
* @param {string} options.model - The name of a model to use
298298
* @returns {{stream: boolean, [k:string]: unknown}} Returns either an object with stream:true property or a completion object
299299
*/
300-
async complete(parameters) {
300+
async complete (parameters) {
301301
const client_driver_call = Context.get('client_driver_call');
302302
let { test_mode, intended_service, response_metadata } = client_driver_call;
303303

@@ -323,17 +323,17 @@ class AIChatService extends BaseService {
323323
}
324324

325325
// Skip moderation for Ollama (local service) and other local services
326-
const should_moderate = ! test_mode &&
327-
intended_service !== 'ollama' &&
328-
! parameters.model?.startsWith('ollama:');
329-
330-
if ( should_moderate && ! await this.moderate(parameters) ) {
326+
const should_moderate = !test_mode &&
327+
intended_service !== 'ollama' &&
328+
!parameters.model?.startsWith('ollama:');
329+
330+
if ( should_moderate && !await this.moderate(parameters) ) {
331331
test_mode = true;
332332
throw APIError.create('moderation_failed');
333333
}
334334

335335
// Only set moderated flag if we actually ran moderation
336-
if ( ! test_mode && should_moderate ) {
336+
if ( !test_mode && should_moderate ) {
337337
Context.set('moderated', true);
338338
}
339339

@@ -367,7 +367,7 @@ class AIChatService extends BaseService {
367367

368368
if ( ! model_details ) {
369369
// TODO (xiaochen): replace with a standard link
370-
const available_models_url = this.global_config.origin + '/puterai/chat/models';
370+
const available_models_url = `${this.global_config.origin }/puterai/chat/models`;
371371

372372
throw APIError.create('field_invalid', null, {
373373
key: 'model',
@@ -384,7 +384,7 @@ class AIChatService extends BaseService {
384384
const usageAllowed = await this.meteringService.hasEnoughCredits(actor, approximate_input_cost);
385385

386386
// Handle usage limits reached case
387-
if ( !usageAllowed ) {
387+
if ( ! usageAllowed ) {
388388
// The check_usage_ method has already updated the intended_service to 'usage-limited-chat'
389389
service_used = 'usage-limited-chat';
390390
model_used = 'usage-limited';
@@ -404,7 +404,7 @@ class AIChatService extends BaseService {
404404
parameters.max_tokens = Math.floor(Math.min(parameters.max_tokens ?? Number.POSITIVE_INFINITY,
405405
max_allowed_output_tokens,
406406
model_max_tokens - (Math.ceil(text.length / 4))));
407-
if (parameters.max_tokens < 1) {
407+
if ( parameters.max_tokens < 1 ) {
408408
parameters.max_tokens = undefined;
409409
}
410410
}
@@ -495,7 +495,7 @@ class AIChatService extends BaseService {
495495
const fallbackUsageAllowed = await this.meteringService.hasEnoughCredits(actor, 1);
496496

497497
// If usage not allowed for fallback, use usage-limited-chat instead
498-
if ( !fallbackUsageAllowed ) {
498+
if ( ! fallbackUsageAllowed ) {
499499
// The check_usage_ method has already updated intended_service
500500
service_used = 'usage-limited-chat';
501501
model_used = 'usage-limited';
@@ -576,10 +576,10 @@ class AIChatService extends BaseService {
576576
this.errors.report('error during stream response', {
577577
source: e,
578578
});
579-
stream.write(JSON.stringify({
579+
stream.write(`${JSON.stringify({
580580
type: 'error',
581581
message: e.message,
582-
}) + '\n');
582+
}) }\n`);
583583
stream.end();
584584
} finally {
585585
if ( ret.result.finally_fn ) {
@@ -631,7 +631,7 @@ class AIChatService extends BaseService {
631631
* Returns false immediately if any message is flagged as inappropriate.
632632
* Returns true if OpenAI service is unavailable or all messages pass moderation.
633633
*/
634-
async moderate({ messages }) {
634+
async moderate ({ messages }) {
635635
if ( process.env.TEST_MODERATION_FAILURE ) return false;
636636
const fulltext = Messages.extract_text(messages);
637637
let mod_last_error = null;
@@ -672,15 +672,15 @@ class AIChatService extends BaseService {
672672
return true;
673673
}
674674

675-
async models_() {
675+
async models_ () {
676676
return this.detail_model_list;
677677
}
678678

679679
/**
680680
* Returns a list of available AI models with basic details
681681
* @returns {Promise<Array>} Array of simple model objects containing basic model information
682682
*/
683-
async list_() {
683+
async list_ () {
684684
return this.simple_model_list;
685685
}
686686

@@ -691,7 +691,7 @@ class AIChatService extends BaseService {
691691
*
692692
* @returns {Object|undefined} The delegate service or undefined if intended service is ai-chat
693693
*/
694-
get_delegate() {
694+
get_delegate () {
695695
const client_driver_call = Context.get('client_driver_call');
696696
if ( client_driver_call.intended_service === this.service_name ) {
697697
return undefined;
@@ -709,8 +709,9 @@ class AIChatService extends BaseService {
709709
* @param {*} param0
710710
* @returns
711711
*/
712-
get_fallback_model({ model, tried }) {
712+
get_fallback_model ({ model, tried }) {
713713
let target_model = this.detail_model_map[model];
714+
714715
if ( ! target_model ) {
715716
this.log.error('could not find model', { model });
716717
throw new Error('could not find model');
@@ -722,24 +723,32 @@ class AIChatService extends BaseService {
722723
}
723724

724725
// First check KV for the sorted list
725-
let sorted_models = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`);
726+
let potentialFallbacks = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`);
726727

727-
if ( ! sorted_models ) {
728+
if ( ! potentialFallbacks ) {
728729
// Calculate the sorted list
729730
const models = this.detail_model_list;
730731

731-
sorted_models = models.toSorted((a, b) => {
732-
return Math.sqrt(Math.pow(a.cost.input - target_model.cost.input, 2) +
733-
Math.pow(a.cost.output - target_model.cost.output, 2)) - Math.sqrt(Math.pow(b.cost.input - target_model.cost.input, 2) +
734-
Math.pow(b.cost.output - target_model.cost.output, 2));
735-
});
732+
let aiProvider, modelToSearch;
733+
if ( target_model.id.startsWith('openrouter:') || target_model.id.startsWith('togetherai:') ) {
734+
[aiProvider, modelToSearch] = target_model.id.replace('openrouter:', '').replace('togetherai:', '').toLowerCase().split('/');
735+
} else {
736+
[aiProvider, modelToSearch] = target_model.provider.toLowerCase(), target_model.id.toLowerCase();
737+
}
738+
739+
const potentialMatches = models.filter(model => {
740+
const possibleModelNames = [`openrouter:${aiProvider}/${modelToSearch}`,
741+
`togetherai:${aiProvider}/${modelToSearch}`, ...(target_model.aliases?.map((alias) => [`openrouter:${aiProvider}/${alias}`,
742+
`togetherai:${aiProvider}/${alias}`])?.flat() ?? [])];
736743

737-
sorted_models = sorted_models.slice(0, MAX_FALLBACKS);
744+
return !possibleModelNames.find(possibleName => model.id.toLowerCase() === possibleName);
745+
}).slice(0, MAX_FALLBACKS);
738746

739-
this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, sorted_models);
747+
this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, potentialMatches);
748+
potentialFallbacks = potentialMatches;
740749
}
741750

742-
for ( const model of sorted_models ) {
751+
for ( const model of potentialFallbacks ) {
743752
if ( tried.includes(model.id) ) continue;
744753
if ( model.provider === 'fake-chat' ) continue;
745754

@@ -751,12 +760,12 @@ class AIChatService extends BaseService {
751760

752761
// No fallbacks available
753762
this.log.error('no fallbacks', {
754-
sorted_models,
763+
potentialFallbacks,
755764
tried,
756765
});
757766
}
758767

759-
get_model_from_request(parameters, modified_context = {}) {
768+
get_model_from_request (parameters, modified_context = {}) {
760769
const client_driver_call = Context.get('client_driver_call');
761770
let { intended_service } = client_driver_call;
762771

0 commit comments

Comments
 (0)