Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 201 additions & 2 deletions src/services/api/openaiShim.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,97 @@ test('preserves usage from final OpenAI stream chunk with empty choices', async
expect(usageEvent?.usage?.output_tokens).toBe(45)
})

test('uses max_tokens instead of max_completion_tokens for local providers', async () => {
process.env.OPENAI_BASE_URL = 'http://localhost:11434/v1'

globalThis.fetch = (async (_input, init) => {
const body = JSON.parse(String(init?.body))
expect(body.max_tokens).toBe(64)
expect(body.max_completion_tokens).toBeUndefined()
expect(body.stream_options).toBeUndefined()

return new Response(
JSON.stringify({
id: 'chatcmpl-1',
model: 'llama3.1:8b',
choices: [
{
message: {
role: 'assistant',
content: 'hello',
},
finish_reason: 'stop',
},
],
usage: {
prompt_tokens: 5,
completion_tokens: 1,
total_tokens: 6,
},
}),
{
headers: {
'Content-Type': 'application/json',
},
},
)
}) as FetchType

const client = createOpenAIShimClient({}) as OpenAIShimClient

await client.beta.messages.create({
model: 'llama3.1:8b',
messages: [{ role: 'user', content: 'hello' }],
max_tokens: 64,
stream: false,
})
})

test('keeps max_completion_tokens for non-local non-github providers', async () => {
process.env.OPENAI_BASE_URL = 'https://api.openai.com/v1'

globalThis.fetch = (async (_input, init) => {
const body = JSON.parse(String(init?.body))
expect(body.max_completion_tokens).toBe(64)
expect(body.max_tokens).toBeUndefined()

return new Response(
JSON.stringify({
id: 'chatcmpl-1',
model: 'gpt-4o',
choices: [
{
message: {
role: 'assistant',
content: 'hello',
},
finish_reason: 'stop',
},
],
usage: {
prompt_tokens: 5,
completion_tokens: 1,
total_tokens: 6,
},
}),
{
headers: {
'Content-Type': 'application/json',
},
},
)
}) as FetchType

const client = createOpenAIShimClient({}) as OpenAIShimClient

await client.beta.messages.create({
model: 'gpt-4o',
messages: [{ role: 'user', content: 'hello' }],
max_tokens: 64,
stream: false,
})
})

test('preserves Gemini tool call extra_content in follow-up requests', async () => {
let requestBody: Record<string, unknown> | undefined

Expand Down Expand Up @@ -689,9 +780,117 @@ test('preserves image tool results as placeholders in follow-up requests', async

const toolMessage = (requestBody?.messages as Array<Record<string, unknown>>).find(
message => message.role === 'tool',
) as { content?: string } | undefined
) as {
content?: Array<{
type: string
text?: string
image_url?: { url: string }
}> | string
} | undefined

expect(Array.isArray(toolMessage?.content)).toBe(true)
const parts = toolMessage?.content as Array<{
type: string
text?: string
image_url?: { url: string }
}>
const imagePart = parts.find(part => part.type === 'image_url')
expect(imagePart?.image_url?.url).toBe('data:image/png;base64,ZmFrZQ==')
})

expect(toolMessage?.content).toContain('[image:image/png]')
test('preserves mixed text and image tool results as multipart content', async () => {
let requestBody: Record<string, unknown> | undefined

globalThis.fetch = (async (_input, init) => {
requestBody = JSON.parse(String(init?.body))

return new Response(
JSON.stringify({
id: 'chatcmpl-1',
model: 'gpt-4o',
choices: [
{
message: {
role: 'assistant',
content: 'done',
},
finish_reason: 'stop',
},
],
usage: {
prompt_tokens: 12,
completion_tokens: 4,
total_tokens: 16,
},
}),
{
headers: {
'Content-Type': 'application/json',
},
},
)
}) as FetchType

const client = createOpenAIShimClient({}) as OpenAIShimClient

await client.beta.messages.create({
model: 'gpt-4o',
system: 'test system',
messages: [
{ role: 'user', content: 'Read this screenshot' },
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_image_2',
name: 'Read',
input: { file_path: 'C:\\temp\\screenshot.png' },
},
],
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'call_image_2',
content: [
{ type: 'text', text: 'Screenshot captured' },
{
type: 'image',
source: {
type: 'base64',
media_type: 'image/png',
data: 'ZmFrZQ==',
},
},
],
},
],
},
],
max_tokens: 64,
stream: false,
})

const toolMessage = (requestBody?.messages as Array<Record<string, unknown>>).find(
message => message.role === 'tool',
) as {
content?: Array<{
type: string
text?: string
image_url?: { url: string }
}>
} | undefined

expect(Array.isArray(toolMessage?.content)).toBe(true)
const parts = toolMessage?.content ?? []
expect(parts[0]).toEqual({ type: 'text', text: 'Screenshot captured' })
expect(parts[1]).toEqual({
type: 'image_url',
image_url: { url: 'data:image/png;base64,ZmFrZQ==' },
})
})

test('uses GEMINI_ACCESS_TOKEN for Gemini OpenAI-compatible requests', async () => {
Expand Down
56 changes: 41 additions & 15 deletions src/services/api/openaiShim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,35 +176,61 @@ function convertSystemPrompt(
return String(system)
}

function convertToolResultContent(content: unknown): string {
if (typeof content === 'string') return content
if (!Array.isArray(content)) return JSON.stringify(content ?? '')
function convertToolResultContent(
content: unknown,
isError?: boolean,
): string | Array<{ type: string; text?: string; image_url?: { url: string } }> {
if (typeof content === 'string') {
return isError ? `Error: ${content}` : content
}
if (!Array.isArray(content)) {
const text = JSON.stringify(content ?? '')
return isError ? `Error: ${text}` : text
}

const chunks: string[] = []
const parts: Array<{
type: string
text?: string
image_url?: { url: string }
}> = []
for (const block of content) {
if (block?.type === 'text' && typeof block.text === 'string') {
chunks.push(block.text)
parts.push({ type: 'text', text: block.text })
continue
}

if (block?.type === 'image') {
const source = block.source
if (source?.type === 'url' && source.url) {
chunks.push(`[Image](${source.url})`)
} else if (source?.type === 'base64') {
chunks.push(`[image:${source.media_type ?? 'unknown'}]`)
} else {
chunks.push('[image]')
parts.push({ type: 'image_url', image_url: { url: source.url } })
} else if (source?.type === 'base64' && source.media_type && source.data) {
parts.push({
type: 'image_url',
image_url: {
url: `data:${source.media_type};base64,${source.data}`,
},
})
}
continue
}

if (typeof block?.text === 'string') {
chunks.push(block.text)
parts.push({ type: 'text', text: block.text })
}
}

return chunks.join('\n')
if (parts.length === 0) return ''
if (parts.length === 1 && parts[0].type === 'text') {
const text = parts[0].text ?? ''
return isError ? `Error: ${text}` : text
}
if (isError && parts[0]?.type === 'text') {
parts[0] = { ...parts[0], text: `Error: ${parts[0].text ?? ''}` }
} else if (isError) {
parts.unshift({ type: 'text', text: 'Error:' })
}

return parts
}

function convertContentBlocks(
Expand Down Expand Up @@ -292,11 +318,10 @@ function convertMessages(

// Emit tool results as tool messages
for (const tr of toolResults) {
const trContent = convertToolResultContent(tr.content)
result.push({
role: 'tool',
tool_call_id: tr.tool_use_id ?? 'unknown',
content: tr.is_error ? `Error: ${trContent}` : trContent,
content: convertToolResultContent(tr.content, tr.is_error),
})
}

Expand Down Expand Up @@ -1216,12 +1241,13 @@ class OpenAIShimMessages {

const isGithub = isGithubModelsMode()
const isMistral = isMistralMode()
const isLocal = isLocalProviderUrl(request.baseUrl)

const githubEndpointType = getGithubEndpointType(request.baseUrl)
const isGithubCopilot = isGithub && githubEndpointType === 'copilot'
const isGithubModels = isGithub && (githubEndpointType === 'models' || githubEndpointType === 'custom')

if ((isGithub || isMistral) && body.max_completion_tokens !== undefined) {
if ((isGithub || isMistral || isLocal) && body.max_completion_tokens !== undefined) {
body.max_tokens = body.max_completion_tokens
delete body.max_completion_tokens
}
Expand Down
Loading