Skip to content

Commit a7522a6

Browse files
committed
fix(ai): return transformed array output elements
1 parent 43ad34c commit a7522a6

4 files changed

Lines changed: 50 additions & 2 deletions

File tree

packages/ai/src/generate-object/generate-object.test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,33 @@ describe('generateObject', () => {
906906
}
907907
`);
908908
});
909+
910+
it('should return transformed element values', async () => {
911+
const model = new MockLanguageModelV4({
912+
doGenerate: {
913+
...dummyResponseValues,
914+
content: [
915+
{
916+
type: 'text',
917+
text: JSON.stringify({
918+
elements: [{ content: 'element 1' }],
919+
}),
920+
},
921+
],
922+
},
923+
});
924+
925+
const result = await generateObject({
926+
model,
927+
schema: z.object({
928+
content: z.string().transform(value => value.toUpperCase()),
929+
}),
930+
output: 'array',
931+
prompt: 'prompt',
932+
});
933+
934+
expect(result.object).toEqual([{ content: 'ELEMENT 1' }]);
935+
});
909936
});
910937

911938
describe('output = "enum"', () => {

packages/ai/src/generate-object/output-strategy.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,19 @@ const arrayOutputStrategy = <ELEMENT>(
239239
}
240240

241241
const inputArray = value.elements as Array<JSONObject>;
242+
const resultArray: Array<ELEMENT> = [];
242243

243244
// check that each element in the array is of the correct type:
244245
for (const element of inputArray) {
245246
const result = await safeValidateTypes({ value: element, schema });
246247
if (!result.success) {
247248
return result;
248249
}
250+
251+
resultArray.push(result.value);
249252
}
250253

251-
return { success: true, value: inputArray as Array<ELEMENT> };
254+
return { success: true, value: resultArray };
252255
},
253256

254257
createElementStream(

packages/ai/src/generate-text/output.test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,21 @@ describe('Output.array', () => {
322322
expect(result).toStrictEqual([{ content: 'test' }]);
323323
});
324324

325+
it('should return transformed element values', async () => {
326+
const arrayWithTransform = array({
327+
element: z.object({
328+
content: z.string().transform(value => value.toUpperCase()),
329+
}),
330+
});
331+
332+
const result = await arrayWithTransform.parseCompleteOutput(
333+
{ text: `{ "elements": [{ "content": "test" }] }` },
334+
context,
335+
);
336+
337+
expect(result).toStrictEqual([{ content: 'TEST' }]);
338+
});
339+
325340
it('should throw NoObjectGeneratedError when parsing fails', async () => {
326341
try {
327342
await array1.parseCompleteOutput({ text: '{ broken json' }, context);

packages/ai/src/generate-text/output.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ export const array = <ELEMENT>({
278278
});
279279
}
280280

281+
const parsedElements: Array<ELEMENT> = [];
281282
for (const element of outerValue.elements) {
282283
const validationResult = await safeValidateTypes({
283284
value: element,
@@ -294,9 +295,11 @@ export const array = <ELEMENT>({
294295
finishReason: context.finishReason,
295296
});
296297
}
298+
299+
parsedElements.push(validationResult.value);
297300
}
298301

299-
return outerValue.elements as Array<ELEMENT>;
302+
return parsedElements;
300303
},
301304

302305
async parsePartialOutput({ text }: { text: string }) {

0 commit comments

Comments
 (0)