
Commit 37c43f5

fix: Updated openai instrumentation to properly handle streaming when stream_options.include_usage is set (newrelic#3494)
1 parent f805d38 commit 37c43f5
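
For context, a minimal sketch (not part of this commit, assuming the `openai` v4 SDK and an `OPENAI_API_KEY` in the environment) of the streaming behavior the fix targets: when `stream_options: { include_usage: true }` is set, the SDK emits one extra final chunk whose `choices` array is empty and whose `usage` field carries the token counts, so instrumentation that only inspects the last chunk's `choices[0]` finds no message data.

'use strict'
const OpenAI = require('openai')

async function main() {
  const client = new OpenAI() // reads OPENAI_API_KEY from the environment
  const stream = await client.chat.completions.create({
    model: 'gpt-4',
    stream: true,
    stream_options: { include_usage: true },
    messages: [{ role: 'user', content: 'What does 1 plus 1 equal?' }]
  })

  for await (const chunk of stream) {
    if (chunk.choices?.length > 0) {
      // ordinary delta chunks carry the message content
      process.stdout.write(chunk.choices[0].delta?.content ?? '')
    } else if (chunk.usage) {
      // the extra final chunk: `choices` is [] and only `usage` is populated
      console.log('\nusage:', chunk.usage)
    }
  }
}

main().catch(console.error)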

File tree: 9 files changed, +210 −53 lines changed

lib/subscribers/openai/utils.js

Lines changed: 11 additions & 2 deletions
@@ -252,6 +252,7 @@ function instrumentStream({ agent, headers, logger, request, response, segment,
   response.iterator = async function * wrappedIterator() {
     let content = ''
     let role = ''
+    let finishReason = ''
     let chunk
     try {
       const iterator = orig.apply(this, arguments)
@@ -261,15 +262,23 @@ function instrumentStream({ agent, headers, logger, request, response, segment,
           role = chunk.choices[0].delta.role
         }

+        if (chunk.choices?.[0]?.finish_reason) {
+          finishReason = chunk.choices[0].finish_reason
+        }
+
         content += chunk.choices?.[0]?.delta?.content ?? ''
         yield chunk
       }
     } catch (streamErr) {
       err = streamErr
       throw err
     } finally {
-      if (chunk?.choices && chunk?.choices?.length !== 0) {
-        chunk.choices[0].message = { role, content }
+      // when `chunk.choices` is an array that means the completions API is being used
+      // we must re-assign the finish reason, and construct a message object with role and content
+      // This is because if `include_usage` is enabled, the last chunk only contains usage info and no message deltas
+      if (Array.isArray(chunk?.choices)) {
+        chunk.choices = [{ finish_reason: finishReason, message: { role, content } }]
+      // This means it is the responses API and the entire message is in the response object
       } else if (chunk?.response) {
         chunk = chunk.response
       }
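
A hedged illustration (not from the repo, values illustrative) of why the guard in the `finally` block changed: with `include_usage` enabled, the last completions chunk has an empty `choices` array, so the old `length !== 0` check skipped message reconstruction entirely, while the new `Array.isArray` check still fires and rebuilds `choices[0].message` from the accumulated deltas. The `else if (chunk?.response)` branch covers the Responses API, where the final event already carries the full message on `chunk.response`.

// Shape of the final completions chunk when include_usage is on (values illustrative)
const finalChunk = {
  id: 'chatcmpl-abc123',
  choices: [],
  usage: { prompt_tokens: 53, completion_tokens: 11, total_tokens: 64 }
}

// Old guard: an empty array is truthy but its length is 0, so this is false
const oldGuard = Boolean(finalChunk?.choices && finalChunk?.choices?.length !== 0)
// New guard: an empty array is still an array, so the message gets reconstructed
const newGuard = Array.isArray(finalChunk?.choices)

console.log({ oldGuard, newGuard }) // { oldGuard: false, newGuard: true }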

test/unit/llm-events/openai/embedding.test.js

Lines changed: 36 additions & 2 deletions
@@ -101,13 +101,19 @@ test('should set error to true', (t, end) => {
   })
 })

-test('respects record_content', (t, end) => {
+test('respects record_content by not recording content when set to false', (t, end) => {
   const { agent } = t.nr
   const req = {
     input: 'This is my test input',
     model: 'gpt-3.5-turbo-0613'
   }
   agent.config.ai_monitoring.record_content.enabled = false
+  function cb(model, content) {
+    return 65
+  }
+
+  const api = helper.getAgentApi()
+  api.setLlmTokenCountCallback(cb)

   helper.runInTransaction(agent, () => {
     const segment = agent.tracer.getSegment()
@@ -118,11 +124,12 @@ test('respects record_content', (t, end) => {
       response: res
     })
     assert.equal(embeddingEvent.input, undefined)
+    assert.equal(embeddingEvent['response.usage.total_tokens'], 65)
     end()
   })
 })

-test('respects record_content', (t, end) => {
+test('respects record_content by recording content when true', (t, end) => {
   const { agent } = t.nr
   const req = {
     input: 'This is my test input',
@@ -144,6 +151,33 @@ test('respects record_content', (t, end) => {
       response: res
     })
     assert.equal(embeddingEvent['response.usage.total_tokens'], 65)
+    assert.equal(embeddingEvent.input, req.input)
+    end()
+  })
+})
+
+test('does not calculate tokens when no content exists', (t, end) => {
+  const { agent } = t.nr
+  const req = {
+    model: 'gpt-3.5-turbo-0613'
+  }
+
+  function cb(model, content) {
+    return 65
+  }
+
+  const api = helper.getAgentApi()
+  api.setLlmTokenCountCallback(cb)
+  helper.runInTransaction(agent, () => {
+    const segment = agent.tracer.getSegment()
+    const embeddingEvent = new LlmEmbedding({
+      agent,
+      segment,
+      request: req,
+      response: res
+    })
+    assert.equal(embeddingEvent['response.usage.total_tokens'], undefined)
+    assert.equal(embeddingEvent.input, undefined)
     end()
   })
 })
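
The new embedding tests exercise the agent's LLM token-count callback. A hedged sketch (using the public `newrelic` module rather than the test helper's `helper.getAgentApi()`, with a deliberately naive count): the callback receives the model name and the content string and returns a number; per the tests above, that number surfaces as `response.usage.total_tokens` on the embedding event, and nothing is recorded when the request has no input.

'use strict'
const newrelic = require('newrelic')

// Register a token-count callback; signature is (model, content) => number
newrelic.setLlmTokenCountCallback(function countTokens(model, content) {
  if (!content) return // no content, no token count (mirrors the "no content" test above)
  // naive approximation for illustration only: ~4 characters per token
  return Math.ceil(content.length / 4)
})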

test/versioned/openai/chat-completions-res-api.test.js

Lines changed: 18 additions & 12 deletions
@@ -143,7 +143,7 @@ test('responses.create', async (t) => {
       })

       const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
-      assertChatCompletionSummary({ tx, model, chatSummary, tokenUsage: true })
+      assertChatCompletionSummary({ tx, model, chatSummary })

       tx.end()
       end()
@@ -174,7 +174,7 @@ test('responses.create', async (t) => {
       })

       const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
-      assertChatCompletionSummary({ tx, model, chatSummary, tokenUsage: true, singleInput: true })
+      assertChatCompletionSummary({ tx, model, chatSummary, singleInput: true })

       tx.end()
       end()
@@ -318,8 +318,7 @@ test('responses.create', async (t) => {
       const stream = await client.responses.create({
         stream: true,
         input: content,
-        model: 'gpt-4',
-        stream_options: { include_usage: true }
+        model: 'gpt-4'
       })

       let chunk = {}
@@ -347,11 +346,11 @@ test('responses.create', async (t) => {
     const { client, agent } = t.nr
     helper.runInTransaction(agent, async (tx) => {
       const content = 'Streamed response'
+      const model = 'gpt-4-0613'
       const stream = await client.responses.create({
         stream: true,
         input: [{ role: 'user', content }, { role: 'user', content: 'What does 1 plus 1 equal?' }],
-        model: 'gpt-4',
-        stream_options: { include_usage: true }
+        model,
       })

       let chunk = {}
@@ -366,10 +365,12 @@ test('responses.create', async (t) => {
         tx,
         chatMsgs,
         id: 'resp_684886977be881928c9db234e14ae7d80f8976796514dff9',
-        model: 'gpt-4-0613',
+        model,
         resContent: res,
         reqContent: content
       })
+      const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
+      assertChatCompletionSummary({ tx, model, chatSummary, streaming: true })

       tx.end()
       end()
@@ -378,23 +379,27 @@ test('responses.create', async (t) => {

   await t.test('should call the tokenCountCallback in streaming', (t, end) => {
     const { client, agent } = t.nr
+    const model = 'gpt-4-0613'
     const promptContent = 'Streamed response'
     const promptContent2 = 'What does 1 plus 1 equal?'
+    const promptTokens = 53
+    const completionTokens = 11
     const res = 'Test stream'
     const api = helper.getAgentApi()
+    // swap the token counts
     function cb(model, content) {
       // could be gpt-4 or gpt-4-0613
       assert.ok(model === 'gpt-4' || model === 'gpt-4-0613', 'should be gpt-4 or gpt-4-0613')
       if (content === promptContent + ' ' + promptContent2) {
-        return 53
+        return promptTokens
       } else if (content === res) {
-        return 11
+        return completionTokens
       }
     }
     api.setLlmTokenCountCallback(cb)
     helper.runInTransaction(agent, async (tx) => {
       const stream = await client.responses.create({
-        model: 'gpt-4',
+        model,
         input: [
           { role: 'user', content: promptContent },
           { role: 'user', content: promptContent2 }
@@ -410,14 +415,15 @@ test('responses.create', async (t) => {
       const events = agent.customEventAggregator.events.toArray()
       const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
       assertChatCompletionMessages({
-        tokenUsage: true,
         tx,
         chatMsgs,
         id: 'resp_684886977be881928c9db234e14ae7d80f8976796514dff9',
-        model: 'gpt-4-0613',
+        model,
         resContent: res,
         reqContent: promptContent
       })
+      const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
+      assertChatCompletionSummary({ tx, model, chatSummary, streaming: true, promptTokens, completionTokens })

       tx.end()
       end()

test/versioned/openai/chat-completions.test.js

Lines changed: 60 additions & 9 deletions
@@ -148,22 +148,21 @@ test('chat.completions.create', async (t) => {
       })

       const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
-      assertChatCompletionSummary({ tx, model, chatSummary, tokenUsage: true })
+      assertChatCompletionSummary({ tx, model, chatSummary })

       tx.end()
       end()
     })
   })

   if (semver.gte(pkgVersion, '4.12.2')) {
-    await t.test('should create span on successful chat completion stream create', { skip: semver.lt(pkgVersion, '4.12.2') }, (t, end) => {
+    await t.test('should create span on successful chat completion stream create', (t, end) => {
       const { client, agent, host, port } = t.nr
       helper.runInTransaction(agent, async (tx) => {
         const content = 'Streamed response'
         const stream = await client.chat.completions.create({
           stream: true,
-          messages: [{ role: 'user', content }],
-          stream_options: { include_usage: true },
+          messages: [{ role: 'user', content }]
         })

         let chunk = {}
@@ -202,8 +201,7 @@ test('chat.completions.create', async (t) => {
             { role: 'user', content },
             { role: 'user', content: 'What does 1 plus 1 equal?' }
           ],
-          stream: true,
-          stream_options: { include_usage: true },
+          stream: true
         })

         let res = ''
@@ -220,6 +218,55 @@ test('chat.completions.create', async (t) => {
           i++
         }

+        const events = agent.customEventAggregator.events.toArray()
+        assert.equal(events.length, 4, 'should create a chat completion message and summary event')
+        const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
+        assertChatCompletionMessages({
+          tx,
+          chatMsgs,
+          id: 'chatcmpl-8MzOfSMbLxEy70lYAolSwdCzfguQZ',
+          model,
+          resContent: res,
+          reqContent: content,
+          noTokenUsage: true
+        })
+
+        const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
+        assertChatCompletionSummary({ tx, model, chatSummary, noUsageTokens: true })
+
+        tx.end()
+        end()
+      })
+    })
+
+    await t.test('should assign usage information when `include_usage` exists in stream', (t, end) => {
+      const { client, agent } = t.nr
+      helper.runInTransaction(agent, async (tx) => {
+        const content = 'Streamed response usage'
+        const model = 'gpt-4'
+        const stream = await client.chat.completions.create({
+          stream: true,
+          model,
+          messages: [
+            { role: 'user', content },
+            { role: 'user', content: 'What does 1 plus 1 equal?' }
+          ],
+          streaming_options: { include_usage: true }
+        })
+
+        let chunk = {}
+        let res = ''
+        for await (chunk of stream) {
+          if (!chunk.usage) {
+            res += chunk.choices[0]?.delta?.content
+          }
+        }
+        assert.equal(chunk.headers, undefined, 'should remove response headers from user result')
+        assert.equal(chunk.choices[0].message.role, 'assistant')
+        const expectedRes = responses.get(content)
+        assert.equal(chunk.choices[0].message.content, expectedRes.streamData)
+        assert.equal(chunk.choices[0].message.content, res)
+        assert.deepEqual(chunk.usage, { prompt_tokens: 53, completion_tokens: 11, total_tokens: 64 })
         const events = agent.customEventAggregator.events.toArray()
         assert.equal(events.length, 4, 'should create a chat completion message and summary event')
         const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
@@ -244,15 +291,17 @@ test('chat.completions.create', async (t) => {
      const { client, agent } = t.nr
      const promptContent = 'Streamed response'
      const promptContent2 = 'What does 1 plus 1 equal?'
+      const promptTokens = 11
+      const completionTokens = 53
      let res = ''
      const expectedModel = 'gpt-4'
      const api = helper.getAgentApi()
      function cb(model, content) {
        assert.equal(model, expectedModel)
        if (content === promptContent + ' ' + promptContent2) {
-          return 53
+          return promptTokens
        } else if (content === res) {
-          return 11
+          return completionTokens
        }
      }
      api.setLlmTokenCountCallback(cb)
@@ -276,7 +325,6 @@ test('chat.completions.create', async (t) => {
        const events = agent.customEventAggregator.events.toArray()
        const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
        assertChatCompletionMessages({
-          tokenUsage: true,
          tx,
          chatMsgs,
          id: 'chatcmpl-8MzOfSMbLxEy70lYAolSwdCzfguQZ',
@@ -285,6 +333,9 @@ test('chat.completions.create', async (t) => {
          reqContent: promptContent
        })

+        const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
+        assertChatCompletionSummary({ tx, model: expectedModel, chatSummary, promptTokens, completionTokens })
+
        tx.end()
        end()
      })

test/versioned/openai/common-chat-api.js

Lines changed: 11 additions & 5 deletions
@@ -13,7 +13,7 @@ module.exports = {
 const { match } = require('../../lib/custom-assertions')

 function assertChatCompletionMessages(
-  { tx, chatMsgs, id, model, reqContent, resContent, tokenUsage },
+  { tx, chatMsgs, id, model, reqContent, resContent, noTokenUsage },
   { assert = require('node:assert') } = {}
 ) {
   const [segment] = tx.trace.getChildren(tx.trace.root.id)
@@ -36,14 +36,14 @@ function assertChatCompletionMessages(
       expectedChatMsg.sequence = 0
       expectedChatMsg.id = `${id}-0`
       expectedChatMsg.content = reqContent
-      if (tokenUsage) {
+      if (!noTokenUsage) {
         expectedChatMsg.token_count = 0
       }
     } else if (msg[1].sequence === 1) {
       expectedChatMsg.sequence = 1
       expectedChatMsg.id = `${id}-1`
       expectedChatMsg.content = 'What does 1 plus 1 equal?'
-      if (tokenUsage) {
+      if (!noTokenUsage) {
         expectedChatMsg.token_count = 0
       }
     } else {
@@ -52,7 +52,7 @@ function assertChatCompletionMessages(
       expectedChatMsg.id = `${id}-2`
       expectedChatMsg.content = resContent
       expectedChatMsg.is_response = true
-      if (tokenUsage) {
+      if (!noTokenUsage) {
         expectedChatMsg.token_count = 0
       }
     }
@@ -63,7 +63,7 @@ function assertChatCompletionMessages(
 }

 function assertChatCompletionSummary(
-  { tx, model, chatSummary, error = false },
+  { tx, model, chatSummary, error = false, promptTokens = 53, completionTokens = 11, totalTokens = 64, noUsageTokens = false },
   { assert = require('node:assert') } = {}
 ) {
   const [segment] = tx.trace.getChildren(tx.trace.root.id)
@@ -90,6 +90,12 @@ function assertChatCompletionSummary(
     error
   }

+  if (!(error || noUsageTokens)) {
+    expectedChatSummary['response.usage.prompt_tokens'] = promptTokens
+    expectedChatSummary['response.usage.completion_tokens'] = completionTokens
+    expectedChatSummary['response.usage.total_tokens'] = totalTokens
+  }
+
   assert.equal(chatSummary[0].type, 'LlmChatCompletionSummary')
   match(chatSummary[1], expectedChatSummary, { assert })
 }
