From bbfd232060c682551f6db8dc6098db1d01d1ebfb Mon Sep 17 00:00:00 2001 From: Daniel Salazar Date: Fri, 21 Nov 2025 10:43:38 -0800 Subject: [PATCH] fix: only fallback to models that are immediately the same --- .../src/modules/puterai/AIChatService.js | 87 ++++++++++--------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/src/backend/src/modules/puterai/AIChatService.js b/src/backend/src/modules/puterai/AIChatService.js index e348fa6b4a..2128b8bbd8 100644 --- a/src/backend/src/modules/puterai/AIChatService.js +++ b/src/backend/src/modules/puterai/AIChatService.js @@ -48,7 +48,7 @@ class AIChatService extends BaseService { }; /** @type {import('../../services/MeteringService/MeteringService').MeteringService} */ - get meteringService(){ + get meteringService () { return this.services.get('meteringService').meteringService; } /** @@ -58,14 +58,14 @@ class AIChatService extends BaseService { * Called during service instantiation. * @private */ - _construct() { + _construct () { this.providers = []; this.simple_model_list = []; this.detail_model_list = []; this.detail_model_map = {}; } - get_model_details(model_name, context) { + get_model_details (model_name, context) { let model_details = this.detail_model_map[model_name]; if ( Array.isArray(model_details) && context ) { for ( const model of model_details ) { @@ -88,7 +88,7 @@ class AIChatService extends BaseService { * as well as an empty object for the detailed model map. 
* @private */ - _init() { + _init () { this.kvkey = this.modules.uuidv4(); this.db = this.services.get('database').get(DB_WRITE, 'ai-usage'); @@ -117,7 +117,7 @@ class AIChatService extends BaseService { * * @returns {Promise} */ - async ['__on_boot.consolidation']() { + async ['__on_boot.consolidation'] () { { const svc_driver = this.services.get('driver'); for ( const provider of this.providers ) { @@ -215,13 +215,13 @@ class AIChatService extends BaseService { } } - register_provider(spec) { + register_provider (spec) { this.providers.push(spec); } static IMPLEMENTS = { ['driver-capabilities']: { - supports_test_mode(iface, method_name) { + supports_test_mode (iface, method_name) { return iface === 'puter-chat-completion' && method_name === 'complete'; }, @@ -253,7 +253,7 @@ class AIChatService extends BaseService { * * @returns {Promise>} Array of model objects with details like id, provider, cost, etc. */ - async models() { + async models () { const delegate = this.get_delegate(); if ( ! delegate ) return await this.models_(); return await delegate.models(); @@ -264,7 +264,7 @@ class AIChatService extends BaseService { * detail. * @returns {Promise} Array of model objects with basic details */ - async list() { + async list () { const delegate = this.get_delegate(); if ( ! delegate ) return await this.list_(); return await delegate.list(); @@ -297,7 +297,7 @@ class AIChatService extends BaseService { * @param {string} options.model - The name of a model to use * @returns {{stream: boolean, [k:string]: unknown}} Returns either an object with stream:true property or a completion object */ - async complete(parameters) { + async complete (parameters) { const client_driver_call = Context.get('client_driver_call'); let { test_mode, intended_service, response_metadata } = client_driver_call; @@ -323,17 +323,17 @@ class AIChatService extends BaseService { } // Skip moderation for Ollama (local service) and other local services - const should_moderate = ! 
test_mode && - intended_service !== 'ollama' && - ! parameters.model?.startsWith('ollama:'); - - if ( should_moderate && ! await this.moderate(parameters) ) { + const should_moderate = !test_mode && + intended_service !== 'ollama' && + !parameters.model?.startsWith('ollama:'); + + if ( should_moderate && !await this.moderate(parameters) ) { test_mode = true; throw APIError.create('moderation_failed'); } // Only set moderated flag if we actually ran moderation - if ( ! test_mode && should_moderate ) { + if ( !test_mode && should_moderate ) { Context.set('moderated', true); } @@ -367,7 +367,7 @@ class AIChatService extends BaseService { if ( ! model_details ) { // TODO (xiaochen): replace with a standard link - const available_models_url = this.global_config.origin + '/puterai/chat/models'; + const available_models_url = `${this.global_config.origin }/puterai/chat/models`; throw APIError.create('field_invalid', null, { key: 'model', @@ -384,7 +384,7 @@ class AIChatService extends BaseService { const usageAllowed = await this.meteringService.hasEnoughCredits(actor, approximate_input_cost); // Handle usage limits reached case - if ( !usageAllowed ) { + if ( ! usageAllowed ) { // The check_usage_ method has eady updated the intended_service to 'usage-limited-chat' service_used = 'usage-limited-chat'; model_used = 'usage-limited'; @@ -404,7 +404,7 @@ class AIChatService extends BaseService { parameters.max_tokens = Math.floor(Math.min(parameters.max_tokens ?? Number.POSITIVE_INFINITY, max_allowed_output_tokens, model_max_tokens - (Math.ceil(text.length / 4)))); - if (parameters.max_tokens < 1) { + if ( parameters.max_tokens < 1 ) { parameters.max_tokens = undefined; } } @@ -495,7 +495,7 @@ class AIChatService extends BaseService { const fallbackUsageAllowed = await this.meteringService.hasEnoughCredits(actor, 1); // If usage not allowed for fallback, use usage-limited-chat instead - if ( !fallbackUsageAllowed ) { + if ( ! 
fallbackUsageAllowed ) { // The check_usage_ method has already updated intended_service service_used = 'usage-limited-chat'; model_used = 'usage-limited'; @@ -576,10 +576,10 @@ class AIChatService extends BaseService { this.errors.report('error during stream response', { source: e, }); - stream.write(JSON.stringify({ + stream.write(`${JSON.stringify({ type: 'error', message: e.message, - }) + '\n'); + }) }\n`); stream.end(); } finally { if ( ret.result.finally_fn ) { @@ -631,7 +631,7 @@ class AIChatService extends BaseService { * Returns false immediately if any message is flagged as inappropriate. * Returns true if OpenAI service is unavailable or all messages pass moderation. */ - async moderate({ messages }) { + async moderate ({ messages }) { if ( process.env.TEST_MODERATION_FAILURE ) return false; const fulltext = Messages.extract_text(messages); let mod_last_error = null; @@ -672,7 +672,7 @@ class AIChatService extends BaseService { return true; } - async models_() { + async models_ () { return this.detail_model_list; } @@ -680,7 +680,7 @@ class AIChatService extends BaseService { * Returns a list of available AI models with basic details * @returns {Promise} Array of simple model objects containing basic model information */ - async list_() { + async list_ () { return this.simple_model_list; } @@ -691,7 +691,7 @@ class AIChatService extends BaseService { * * @returns {Object|undefined} The delegate service or undefined if intended service is ai-chat */ - get_delegate() { + get_delegate () { const client_driver_call = Context.get('client_driver_call'); if ( client_driver_call.intended_service === this.service_name ) { return undefined; @@ -709,8 +709,9 @@ class AIChatService extends BaseService { * @param {*} param0 * @returns */ - get_fallback_model({ model, tried }) { + get_fallback_model ({ model, tried }) { let target_model = this.detail_model_map[model]; + if ( ! 
target_model ) { this.log.error('could not find model', { model }); throw new Error('could not find model'); } @@ -722,24 +723,32 @@ class AIChatService extends BaseService { } // First check KV for the sorted list - let sorted_models = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`); + let potentialFallbacks = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`); - if ( ! sorted_models ) { + if ( ! potentialFallbacks ) { // Calculate the sorted list const models = this.detail_model_list; - sorted_models = models.toSorted((a, b) => { - return Math.sqrt(Math.pow(a.cost.input - target_model.cost.input, 2) + - Math.pow(a.cost.output - target_model.cost.output, 2)) - Math.sqrt(Math.pow(b.cost.input - target_model.cost.input, 2) + - Math.pow(b.cost.output - target_model.cost.output, 2)); - }); + let aiProvider, modelToSearch; + if ( target_model.id.startsWith('openrouter:') || target_model.id.startsWith('togetherai:') ) { + [aiProvider, modelToSearch] = target_model.id.replace('openrouter:', '').replace('togetherai:', '').toLowerCase().split('/'); + } else { + [aiProvider, modelToSearch] = [target_model.provider.toLowerCase(), target_model.id.toLowerCase()]; + } + + const potentialMatches = models.filter(model => { + const possibleModelNames = [`openrouter:${aiProvider}/${modelToSearch}`, + `togetherai:${aiProvider}/${modelToSearch}`, ...(target_model.aliases?.map((alias) => [`openrouter:${aiProvider}/${alias}`, + `togetherai:${aiProvider}/${alias}`])?.flat() ??
[])]; - sorted_models = sorted_models.slice(0, MAX_FALLBACKS); + return possibleModelNames.find(possibleName => model.id.toLowerCase() === possibleName); + }).slice(0, MAX_FALLBACKS); - this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, sorted_models); + this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, potentialMatches); + potentialFallbacks = potentialMatches; } - for ( const model of sorted_models ) { + for ( const model of potentialFallbacks ) { if ( tried.includes(model.id) ) continue; if ( model.provider === 'fake-chat' ) continue; @@ -751,12 +760,12 @@ class AIChatService extends BaseService { // No fallbacks available this.log.error('no fallbacks', { - sorted_models, + potentialFallbacks, tried, }); } - get_model_from_request(parameters, modified_context = {}) { + get_model_from_request (parameters, modified_context = {}) { const client_driver_call = Context.get('client_driver_call'); let { intended_service } = client_driver_call;