Skip to content

Commit 3bec7fc

Browse files
authored
feat(ai-sdk): disable validation at oRPC level in createTool to avoid double validation (#1166)
Refactors validation handling in tool creation by setting input/output validation indices to NaN, replacing the previous proxy-based approach. This prevents validation from occurring twice - once at the tool level and once at the procedure call level - which was causing issues when schemas transform data into different shapes. Includes test coverage for the new validation disabling behavior and updates related tRPC integration to use the same approach. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Tool execution now forwards abort/cancel signals and accepts calling options for finer control. * Added support to disable input/output validation when desired for flexible handling. * **Documentation** * Simplified docs and removed a redundant validation warning for clearer guidance. * **Tests** * Expanded tests covering abort-signal propagation and disabled-validation scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 6547b90 commit 3bec7fc

File tree

6 files changed

+72
-42
lines changed

6 files changed

+72
-42
lines changed

apps/content/docs/integrations/ai-sdk.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ const getWeatherTool = implementTool(getWeatherContract, {
160160
location,
161161
temperature: 72 + Math.floor(Math.random() * 21) - 10,
162162
}),
163+
// ...add any additional configuration or overrides here
163164
})
164165
```
165166

@@ -210,13 +211,10 @@ const getWeatherProcedure = base
210211

211212
const getWeatherTool = createTool(getWeatherProcedure, {
212213
context: {}, // provide initial context if needed
214+
// ...add any additional configuration or overrides here
213215
})
214216
```
215217

216218
::: warning
217219
The `createTool` helper requires a procedure with an `input` schema defined
218220
:::
219-
220-
::: warning
221-
Validation occurs twice (once for the tool, once for the procedure call). So validation may fail if `inputSchema` or `outputSchema` transform the data into different shapes.
222-
:::

packages/ai-sdk/src/tool.test.ts

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ describe('implementTool', () => {
7171
})
7272

7373
describe('createTool', () => {
74+
const abortSignal = (new AbortController()).signal
7475
const base = os.$meta<AiSdkToolMeta>({})
7576

7677
const inputSchema = z.object({
@@ -80,7 +81,7 @@ describe('createTool', () => {
8081
greeting: z.string().describe('Greeting message'),
8182
})
8283

83-
it('can create a tool', () => {
84+
it('can create a tool', async () => {
8485
const handler = vi.fn(async ({ input }) => {
8586
return {
8687
greeting: `Hello, ${input.name}!`,
@@ -103,13 +104,27 @@ describe('createTool', () => {
103104
expect(tool.outputSchema).toBe(outputSchema)
104105
expect(tool.description).toBe('Greet a person')
105106

106-
return expect(
107-
(tool as any).execute({ name: 'Alice' }),
108-
).resolves.toEqual({ greeting: 'Hello, Alice!' })
107+
await expect((tool as any).execute({ name: 'Alice' }, { abortSignal })).resolves.toEqual({ greeting: 'Hello, Alice!' })
109108

110109
expect(handler).toHaveBeenCalledWith(expect.objectContaining({
110+
signal: abortSignal,
111111
input: { name: 'Alice' },
112112
context: { authToken: 'auth-token' },
113113
}))
114114
})
115+
116+
it('disable validation at oRPC level to avoid twice times validation', async () => {
117+
const procedure
118+
= base
119+
.route({
120+
summary: 'Greet a person',
121+
})
122+
.input(inputSchema)
123+
.output(outputSchema)
124+
.handler(({ input }) => input as any)
125+
126+
const tool = createTool(procedure)
127+
128+
await expect(tool.execute?.('invalid' as any, { abortSignal } as any)).resolves.toEqual('invalid')
129+
})
115130
})

packages/ai-sdk/src/tool.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import type { ClientOptions } from '@orpc/client'
22
import type { AnySchema, ContractProcedure, ErrorMap, InferSchemaInput, InferSchemaOutput, Meta, Schema } from '@orpc/contract'
3-
import type { Context, CreateProcedureClientOptions, Procedure } from '@orpc/server'
3+
import type { Context, CreateProcedureClientOptions } from '@orpc/server'
44
import type { MaybeOptionalOptions, SetOptional } from '@orpc/shared'
55
import type { Tool } from 'ai'
6-
import { call } from '@orpc/server'
6+
import { call, Procedure } from '@orpc/server'
77
import { resolveMaybeOptionalOptions } from '@orpc/shared'
88
import { tool } from 'ai'
99

@@ -88,8 +88,6 @@ export function implementTool<TOutInput, TInOutput>(
8888
* by leveraging existing procedure definitions.
8989
*
9090
* @warning Requires a contract with an `input` schema defined.
91-
* @warning Validation occurs twice (once for the tool, once for the procedure call).
92-
* So validation may fail if inputSchema or outputSchema transform the data into different shapes.
9391
*
9492
* @example
9593
* ```ts
@@ -147,9 +145,19 @@ export function createTool<
147145
const options = resolveMaybeOptionalOptions(rest)
148146

149147
return implementTool(procedure, {
150-
execute: (input: InferSchemaOutput<TInputSchema>) => {
151-
return call(procedure, input as InferSchemaInput<TInputSchema>, options)
152-
},
148+
execute: ((input, callingOptions) => {
149+
const disabledValidation = new Procedure({
150+
...procedure['~orpc'],
151+
inputValidationIndex: Number.NaN, // disable input validation
152+
outputValidationIndex: Number.NaN, // disable output validation
153+
})
154+
155+
return call(
156+
disabledValidation,
157+
input as InferSchemaInput<TInputSchema>,
158+
{ signal: callingOptions.abortSignal, ...options },
159+
) as Promise<InferSchemaInput<TOutputSchema>>
160+
}) satisfies (Tool<InferSchemaOutput<TInputSchema>, InferSchemaInput<TOutputSchema>>['execute']),
153161
...options,
154162
} as any)
155163
}

packages/server/src/procedure-client.test.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,21 @@ it('has helper `output` in meta', async () => {
652652

653653
expect(preMid1).toReturnWith(Promise.resolve({ output: { val: '99990' }, context: {} }))
654654
})
655+
656+
it('support disable input/output validation by setting validation index to NaN', async () => {
657+
const procedure = new Procedure({
658+
inputSchema: schema,
659+
outputSchema: schema,
660+
errorMap: {},
661+
route: {},
662+
meta: {},
663+
handler: ({ input }) => input,
664+
middlewares: [],
665+
inputValidationIndex: Number.NaN,
666+
outputValidationIndex: Number.NaN,
667+
})
668+
669+
const client = createProcedureClient(procedure)
670+
671+
await expect(client('invalid' as any)).resolves.toEqual('invalid')
672+
})

packages/trpc/src/to-orpc-router.test.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { call, createRouterClient, getEventMeta, isLazy, isProcedure, ORPCError, unlazy } from '@orpc/server'
1+
import { call, createRouterClient, getEventMeta, isLazy, isProcedure, ORPCError, Procedure, unlazy } from '@orpc/server'
22
import { isAsyncIteratorObject } from '@orpc/shared'
33
import { tracked, TRPCError } from '@trpc/server'
44
import * as z from 'zod'
@@ -31,15 +31,22 @@ describe('toORPCRouter', async () => {
3131
expect(await unlazy(orpcRouter.lazy.lazy.throw)).toEqual({ default: expect.toSatisfy(isProcedure) })
3232
})
3333

34-
it('with disabled input/output', async () => {
34+
it('with input/output schema and validation happen inside handler only', async () => {
3535
expect((orpcRouter as any).ping['~orpc'].inputSchema['~standard'].vendor).toBe('zod')
3636
expect((orpcRouter as any).ping['~orpc'].inputSchema._def).toBe(inputSchema._def)
37+
expect((orpcRouter as any).ping['~orpc'].inputValidationIndex).toBe(Number.NaN) // input validation is disabled
38+
3739
expect((orpcRouter as any).ping['~orpc'].outputSchema['~standard'].vendor).toBe('zod')
3840
expect((orpcRouter as any).ping['~orpc'].outputSchema._def).toBe(outputSchema._def)
41+
expect((orpcRouter as any).ping['~orpc'].outputValidationIndex).toBe(Number.NaN) // output validation is disabled
42+
43+
const withoutHandlerProcedure = new Procedure({
44+
...(orpcRouter as any).ping['~orpc'],
45+
handler: async ({ input }) => input,
46+
})
3947

40-
const invalidValue = 'INVALID'
41-
expect((orpcRouter as any).ping['~orpc'].inputSchema['~standard'].validate(invalidValue)).toEqual({ value: invalidValue })
42-
expect((orpcRouter as any).ping['~orpc'].outputSchema['~standard'].validate(invalidValue)).toEqual({ value: invalidValue })
48+
await expect(call(withoutHandlerProcedure, 'invalid')).resolves.toEqual('invalid') // validation not happen at oRPC level
49+
await expect(call((orpcRouter as any).ping, 'invalid')).rejects.toThrow('Invalid input') // validation happen at tRPC level
4350
})
4451

4552
it('meta/route', async () => {

packages/trpc/src/to-orpc-router.ts

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ function toORPCProcedure(procedure: AnyProcedure) {
8989
return new ORPC.Procedure({
9090
errorMap: {},
9191
meta: procedure._def.meta ?? {},
92-
inputValidationIndex: 0,
93-
outputValidationIndex: 0,
9492
route: get(procedure._def.meta, ['route']) ?? {},
9593
middlewares: [],
96-
inputSchema: toDisabledStandardSchema(procedure._def.inputs.at(-1)),
97-
outputSchema: toDisabledStandardSchema((procedure as any)._def.output),
94+
inputSchema: toStandardSchema(procedure._def.inputs.at(-1)),
95+
outputSchema: toStandardSchema((procedure._def as any).output),
96+
inputValidationIndex: Number.NaN, // disable input validation
97+
outputValidationIndex: Number.NaN, // disable output validation
9898
handler: async ({ context, signal, path, input, lastEventId }) => {
9999
try {
100100
const trpcInput = lastEventId !== undefined && (input === undefined || isObject(input))
@@ -151,26 +151,10 @@ function toORPCProcedure(procedure: AnyProcedure) {
151151
* Wraps a TRPC schema to disable validation in the ORPC context.
152152
* This is necessary because tRPC procedure calling already validates the input/output,
153153
*/
154-
function toDisabledStandardSchema(schema: undefined | Parser): undefined | ORPC.Schema<unknown, unknown> {
154+
function toStandardSchema(schema: undefined | Parser): undefined | ORPC.Schema<unknown, unknown> {
155155
if (!isTypescriptObject(schema) || !('~standard' in schema) || !isTypescriptObject(schema['~standard'])) {
156156
return undefined
157157
}
158158

159-
return new Proxy(schema as any, {
160-
get: (target, prop) => {
161-
if (prop === '~standard') {
162-
return new Proxy(target['~standard'], {
163-
get: (target, prop) => {
164-
if (prop === 'validate') {
165-
return (value: any) => ({ value })
166-
}
167-
168-
return Reflect.get(target, prop, target)
169-
},
170-
})
171-
}
172-
173-
return Reflect.get(target, prop, target)
174-
},
175-
})
159+
return schema as any
176160
}

0 commit comments

Comments
 (0)