87 changes: 48 additions & 39 deletions src/backend/src/modules/puterai/AIChatService.js
@@ -48,7 +48,7 @@ class AIChatService extends BaseService {
};

/** @type {import('../../services/MeteringService/MeteringService').MeteringService} */
-get meteringService(){
+get meteringService () {
return this.services.get('meteringService').meteringService;
}
/**
@@ -58,14 +58,14 @@ class AIChatService extends BaseService {
* Called during service instantiation.
* @private
*/
-_construct() {
+_construct () {
this.providers = [];
this.simple_model_list = [];
this.detail_model_list = [];
this.detail_model_map = {};
}

-get_model_details(model_name, context) {
+get_model_details (model_name, context) {
let model_details = this.detail_model_map[model_name];
if ( Array.isArray(model_details) && context ) {
for ( const model of model_details ) {
@@ -88,7 +88,7 @@ class AIChatService extends BaseService {
* as well as an empty object for the detailed model map.
* @private
*/
-_init() {
+_init () {
this.kvkey = this.modules.uuidv4();

this.db = this.services.get('database').get(DB_WRITE, 'ai-usage');
@@ -117,7 +117,7 @@ class AIChatService extends BaseService {
*
* @returns {Promise<void>}
*/
-async ['__on_boot.consolidation']() {
+async ['__on_boot.consolidation'] () {
{
const svc_driver = this.services.get('driver');
for ( const provider of this.providers ) {
@@ -215,13 +215,13 @@ class AIChatService extends BaseService {
}
}

-register_provider(spec) {
+register_provider (spec) {
this.providers.push(spec);
}

static IMPLEMENTS = {
['driver-capabilities']: {
-supports_test_mode(iface, method_name) {
+supports_test_mode (iface, method_name) {
return iface === 'puter-chat-completion' &&
method_name === 'complete';
},
Expand Down Expand Up @@ -253,7 +253,7 @@ class AIChatService extends BaseService {
*
* @returns {Promise<Array<Object>>} Array of model objects with details like id, provider, cost, etc.
*/
-async models() {
+async models () {
const delegate = this.get_delegate();
if ( ! delegate ) return await this.models_();
return await delegate.models();
@@ -264,7 +264,7 @@ class AIChatService extends BaseService {
* detail.
* @returns {Promise<Array<string>>} Array of model objects with basic details
*/
-async list() {
+async list () {
const delegate = this.get_delegate();
if ( ! delegate ) return await this.list_();
return await delegate.list();
@@ -297,7 +297,7 @@ class AIChatService extends BaseService {
* @param {string} options.model - The name of a model to use
* @returns {{stream: boolean, [k:string]: unknown}} Returns either an object with stream:true property or a completion object
*/
-async complete(parameters) {
+async complete (parameters) {
const client_driver_call = Context.get('client_driver_call');
let { test_mode, intended_service, response_metadata } = client_driver_call;

@@ -323,17 +323,17 @@ class AIChatService extends BaseService {
}

// Skip moderation for Ollama (local service) and other local services
-const should_moderate = ! test_mode &&
-intended_service !== 'ollama' &&
-! parameters.model?.startsWith('ollama:');
-if ( should_moderate && ! await this.moderate(parameters) ) {
+const should_moderate = !test_mode &&
+intended_service !== 'ollama' &&
+!parameters.model?.startsWith('ollama:');
+
+if ( should_moderate && !await this.moderate(parameters) ) {
test_mode = true;
throw APIError.create('moderation_failed');
}

// Only set moderated flag if we actually ran moderation
-if ( ! test_mode && should_moderate ) {
+if ( !test_mode && should_moderate ) {
Context.set('moderated', true);
}

@@ -367,7 +367,7 @@ class AIChatService extends BaseService {

if ( ! model_details ) {
// TODO (xiaochen): replace with a standard link
-const available_models_url = this.global_config.origin + '/puterai/chat/models';
+const available_models_url = `${this.global_config.origin}/puterai/chat/models`;

throw APIError.create('field_invalid', null, {
key: 'model',
@@ -384,7 +384,7 @@ class AIChatService extends BaseService {
const usageAllowed = await this.meteringService.hasEnoughCredits(actor, approximate_input_cost);

// Handle usage limits reached case
-if ( !usageAllowed ) {
+if ( ! usageAllowed ) {
// The check_usage_ method has already updated the intended_service to 'usage-limited-chat'
service_used = 'usage-limited-chat';
model_used = 'usage-limited';
@@ -404,7 +404,7 @@ class AIChatService extends BaseService {
parameters.max_tokens = Math.floor(Math.min(parameters.max_tokens ?? Number.POSITIVE_INFINITY,
max_allowed_output_tokens,
model_max_tokens - (Math.ceil(text.length / 4))));
-if (parameters.max_tokens < 1) {
+if ( parameters.max_tokens < 1 ) {
parameters.max_tokens = undefined;
}
}
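
The clamp above keeps the completion budget within both the per-request output cap and the model's remaining context window, using the code's rough 4-characters-per-token estimate. A worked example with hypothetical numbers:

```js
// Hypothetical numbers, for illustration only.
const requested_max_tokens = 100000;     // parameters.max_tokens from the request
const max_allowed_output_tokens = 16384; // per-request output cap
const model_max_tokens = 128000;         // model context window
const text = 'x'.repeat(400000);         // prompt text: ~100k tokens at 4 chars/token

let max_tokens = Math.floor(Math.min(
    requested_max_tokens ?? Number.POSITIVE_INFINITY,
    max_allowed_output_tokens,
    model_max_tokens - Math.ceil(text.length / 4), // 128000 - 100000 = 28000 left
));
console.log(max_tokens); // 16384 — the output cap wins
if ( max_tokens < 1 ) max_tokens = undefined; // fall back to the provider default
```
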
@@ -495,7 +495,7 @@ class AIChatService extends BaseService {
const fallbackUsageAllowed = await this.meteringService.hasEnoughCredits(actor, 1);

// If usage not allowed for fallback, use usage-limited-chat instead
-if ( !fallbackUsageAllowed ) {
+if ( ! fallbackUsageAllowed ) {
// The check_usage_ method has already updated intended_service
service_used = 'usage-limited-chat';
model_used = 'usage-limited';
@@ -576,10 +576,10 @@ class AIChatService extends BaseService {
this.errors.report('error during stream response', {
source: e,
});
-stream.write(JSON.stringify({
+stream.write(`${JSON.stringify({
type: 'error',
message: e.message,
-}) + '\n');
+})}\n`);
stream.end();
} finally {
if ( ret.result.finally_fn ) {
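
The error frame written above goes out as one JSON object per line. A hypothetical consumer of that stream might handle it like this (sample frame only; the real stream interleaves other frame types):

```js
// Sample NDJSON frame as the service would emit it on a stream error.
const chunk = '{"type":"error","message":"upstream failed"}\n';
for ( const line of chunk.split('\n') ) {
    if ( ! line ) continue;
    const frame = JSON.parse(line);
    if ( frame.type === 'error' ) console.error('upstream error:', frame.message);
}
```
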
@@ -631,7 +631,7 @@ class AIChatService extends BaseService {
* Returns false immediately if any message is flagged as inappropriate.
* Returns true if OpenAI service is unavailable or all messages pass moderation.
*/
-async moderate({ messages }) {
+async moderate ({ messages }) {
if ( process.env.TEST_MODERATION_FAILURE ) return false;
const fulltext = Messages.extract_text(messages);
let mod_last_error = null;
@@ -672,15 +672,15 @@ class AIChatService extends BaseService {
return true;
}

-async models_() {
+async models_ () {
return this.detail_model_list;
}

/**
* Returns a list of available AI models with basic details
* @returns {Promise<Array>} Array of simple model objects containing basic model information
*/
-async list_() {
+async list_ () {
return this.simple_model_list;
}

@@ -691,7 +691,7 @@ class AIChatService extends BaseService {
*
* @returns {Object|undefined} The delegate service or undefined if intended service is ai-chat
*/
-get_delegate() {
+get_delegate () {
const client_driver_call = Context.get('client_driver_call');
if ( client_driver_call.intended_service === this.service_name ) {
return undefined;
@@ -709,8 +709,9 @@ class AIChatService extends BaseService {
* @param {*} param0
* @returns
*/
-get_fallback_model({ model, tried }) {
+get_fallback_model ({ model, tried }) {
let target_model = this.detail_model_map[model];

if ( ! target_model ) {
this.log.error('could not find model', { model });
throw new Error('could not find model');
@@ -722,24 +723,32 @@ class AIChatService extends BaseService {
}

// First check KV for the sorted list
-let sorted_models = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`);
+let potentialFallbacks = this.modules.kv.get(`${this.kvkey}:fallbacks:${model}`);

-if ( ! sorted_models ) {
+if ( ! potentialFallbacks ) {
// Calculate the sorted list
const models = this.detail_model_list;

-sorted_models = models.toSorted((a, b) => {
-return Math.sqrt(Math.pow(a.cost.input - target_model.cost.input, 2) +
-Math.pow(a.cost.output - target_model.cost.output, 2)) - Math.sqrt(Math.pow(b.cost.input - target_model.cost.input, 2) +
-Math.pow(b.cost.output - target_model.cost.output, 2));
-});
-sorted_models = sorted_models.slice(0, MAX_FALLBACKS);
-this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, sorted_models);
+let aiProvider, modelToSearch;
+if ( target_model.id.startsWith('openrouter:') || target_model.id.startsWith('togetherai:') ) {
+[aiProvider, modelToSearch] = target_model.id.replace('openrouter:', '').replace('togetherai:', '').toLowerCase().split('/');
+} else {
+[aiProvider, modelToSearch] = [target_model.provider.toLowerCase(), target_model.id.toLowerCase()];
+}
+
+const potentialMatches = models.filter(model => {
+const possibleModelNames = [`openrouter:${aiProvider}/${modelToSearch}`,
+`togetherai:${aiProvider}/${modelToSearch}`, ...(target_model.aliases?.map((alias) => [`openrouter:${aiProvider}/${alias}`,
+`togetherai:${aiProvider}/${alias}`])?.flat() ?? [])];
+
+return !possibleModelNames.find(possibleName => model.id.toLowerCase() === possibleName);
+}).slice(0, MAX_FALLBACKS);
+
+this.modules.kv.set(`${this.kvkey}:fallbacks:${model}`, potentialMatches);
+potentialFallbacks = potentialMatches;
}

-for ( const model of sorted_models ) {
+for ( const model of potentialFallbacks ) {
if ( tried.includes(model.id) ) continue;
if ( model.provider === 'fake-chat' ) continue;

Expand All @@ -751,12 +760,12 @@ class AIChatService extends BaseService {

// No fallbacks available
this.log.error('no fallbacks', {
-sorted_models,
+potentialFallbacks,
tried,
});
}

-get_model_from_request(parameters, modified_context = {}) {
+get_model_from_request (parameters, modified_context = {}) {
const client_driver_call = Context.get('client_driver_call');
let { intended_service } = client_driver_call;

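
For reference, a condensed sketch of the new fallback selection in `get_fallback_model`: instead of ranking every model by cost distance, it derives a provider/model pair from the target and filters out candidates whose ids are just other spellings of the same model under the `openrouter:`/`togetherai:` gateways, then takes the first `MAX_FALLBACKS`. The model records below are illustrative, not the service's actual `detail_model_list` entries:

```js
const MAX_FALLBACKS = 3; // assumed value of the module's constant
const target_model = {
    id: 'openrouter:anthropic/claude-3.5-sonnet',
    provider: 'openrouter',
    aliases: ['claude-3-5-sonnet'],
};
const models = [
    { id: 'togetherai:anthropic/claude-3.5-sonnet', provider: 'togetherai' }, // same model, other gateway: excluded
    { id: 'openrouter:meta-llama/llama-3-70b', provider: 'openrouter' },      // kept as a candidate
    { id: 'gpt-4o', provider: 'openai-completion' },                          // kept as a candidate
];

// Derive the provider/model pair from the target's id (or provider field).
let aiProvider, modelToSearch;
if ( target_model.id.startsWith('openrouter:') || target_model.id.startsWith('togetherai:') ) {
    [aiProvider, modelToSearch] = target_model.id
        .replace('openrouter:', '').replace('togetherai:', '')
        .toLowerCase().split('/');
} else {
    [aiProvider, modelToSearch] = [target_model.provider.toLowerCase(), target_model.id.toLowerCase()];
}

// All spellings of the target model to exclude, including alias variants.
const possibleModelNames = [
    `openrouter:${aiProvider}/${modelToSearch}`,
    `togetherai:${aiProvider}/${modelToSearch}`,
    ...(target_model.aliases?.flatMap(alias => [
        `openrouter:${aiProvider}/${alias}`,
        `togetherai:${aiProvider}/${alias}`,
    ]) ?? []),
];

const potentialFallbacks = models
    .filter(m => !possibleModelNames.includes(m.id.toLowerCase()))
    .slice(0, MAX_FALLBACKS);
console.log(potentialFallbacks.map(m => m.id));
// => [ 'openrouter:meta-llama/llama-3-70b', 'gpt-4o' ]
```

Note that the filter only removes alternate gateway spellings of the target itself; everything else in the list is a potential fallback, taken in list order rather than by cost proximity as before.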