Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 77 additions & 31 deletions packages/durable-iterator/src/client/plugin.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { StandardRPCHandler } from '@orpc/server/standard'
import { isAsyncIteratorObject, sleep } from '@orpc/shared'
import { decodeRequestMessage, encodeResponseMessage, MessageType } from '@orpc/standard-server-peer'
import { WebSocket as ReconnectableWebSocket } from 'partysocket'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { DURABLE_ITERATOR_TOKEN_PARAM } from '../consts'
import { DurableIteratorError } from '../error'
import { DurableIterator } from '../iterator'
Expand All @@ -12,6 +13,8 @@ import { parseDurableIteratorToken } from '../schemas'
import { getClientDurableIteratorToken } from './iterator'
import { DurableIteratorLinkPlugin } from './plugin'

const realSetTimeout = globalThis.setTimeout

vi.mock('partysocket', () => {
return {
WebSocket: vi.fn(() => ({
Expand All @@ -34,6 +37,7 @@ describe('durableIteratorLinkPlugin', async () => {
() => new DurableIterator<any, any>('some-room', { signingKey: 'signing-key', tags: ['tag'] }).rpc('getUser', 'sendMessage'),
)
const refreshTokenBeforeExpireInSeconds = vi.fn(() => Number.NaN)
const refreshTokenDelayInSeconds = vi.fn(() => 2)

const handler = new StandardRPCHandler({
durableIterator: os.handler(durableIteratorHandler),
Expand Down Expand Up @@ -61,6 +65,7 @@ describe('durableIteratorLinkPlugin', async () => {
new DurableIteratorLinkPlugin({
url: 'ws://localhost',
refreshTokenBeforeExpireInSeconds,
refreshTokenDelayInSeconds,
}),
],
})
Expand Down Expand Up @@ -143,8 +148,18 @@ describe('durableIteratorLinkPlugin', async () => {
})

describe('refresh expired token', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.setSystemTime(new Date('2022-01-01T00:00:00.000Z'))
})
afterEach(async () => {
await new Promise(resolve => realSetTimeout(resolve, 1000)) // await for all promises resolved
expect(vi.getTimerCount()).toBe(0) // every is cleanup
vi.restoreAllMocks()
})
Comment thread
dinwwwh marked this conversation as resolved.

it('works', async () => {
refreshTokenBeforeExpireInSeconds.mockImplementation(() => 8)
refreshTokenBeforeExpireInSeconds.mockImplementation(() => 9)
durableIteratorHandler.mockImplementation(
() => new DurableIterator<any, any>('some-room', {
signingKey: 'signing-key',
Expand Down Expand Up @@ -174,6 +189,8 @@ describe('durableIteratorLinkPlugin', async () => {
const output = await outputPromise
expect(output).toSatisfy(isAsyncIteratorObject)

expect(vi.getTimerCount()).toBe(1) // refresh token is enabled

const urlProvider = vi.mocked(ReconnectableWebSocket).mock.calls[0]![0] as any
const ws = vi.mocked(ReconnectableWebSocket).mock.results[0]!.value
ws.send.mockClear()
Expand All @@ -183,22 +200,46 @@ describe('durableIteratorLinkPlugin', async () => {
expect(token1).toBeTypeOf('string')
expect(durableIteratorHandler).toHaveBeenCalledTimes(1)

await sleep(1000)
vi.advanceTimersByTime(500) // not expired yet
expect(vi.getTimerCount()).toBe(1) // no refresh executed
expect(await urlProvider()).toEqual(url1)
expect(getClientDurableIteratorToken(output)).toEqual(token1)
expect(durableIteratorHandler).toHaveBeenCalledTimes(1) // not expired yet
expect(durableIteratorHandler).toHaveBeenCalledTimes(1)

await sleep(1000)
expect(await urlProvider()).not.toEqual(url1)
expect(getClientDurableIteratorToken(output)).not.toEqual(token1)
vi.advanceTimersByTime(500) // expired
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(r => realSetTimeout(r, 10)) // wait for token refresh promise
const url2 = await urlProvider()
expect(url2).not.toEqual(url1)
const token2 = getClientDurableIteratorToken(output)
expect(token2).not.toEqual(token1)
expect(durableIteratorHandler).toHaveBeenCalledTimes(2)
expect(ws.send).toHaveBeenCalledTimes(1) // send set token request to durable iterator
expect(vi.getTimerCount()).toBe(1) // new timer started

vi.advanceTimersByTime(2000) // wait next retry + refreshTokenDelayInSeconds delay
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(r => realSetTimeout(r, 10)) // wait for token refresh promise
const url3 = await urlProvider()
expect(url3).not.toEqual(url1)
expect(url3).not.toEqual(url2)
const token3 = getClientDurableIteratorToken(output)
expect(token3).not.toEqual(token1)
expect(token3).not.toEqual(token2)
expect(durableIteratorHandler).toHaveBeenCalledTimes(3)
expect(ws.send).toHaveBeenCalledTimes(2) // send set token request to durable iterator
expect(vi.getTimerCount()).toBe(1) // new timer started

expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledTimes(2)
expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledTimes(3)
expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledWith(
parseDurableIteratorToken(new URL(url1).searchParams.get(DURABLE_ITERATOR_TOKEN_PARAM)!),
expect.objectContaining({ path: ['durableIterator'] }),
)
expect(refreshTokenDelayInSeconds).toHaveBeenCalledTimes(3)
expect(refreshTokenDelayInSeconds).toHaveBeenCalledWith(
parseDurableIteratorToken(new URL(url1).searchParams.get(DURABLE_ITERATOR_TOKEN_PARAM)!),
expect.objectContaining({ path: ['durableIterator'] }),
)

await output.return() // cleanup
})
Expand Down Expand Up @@ -234,28 +275,12 @@ describe('durableIteratorLinkPlugin', async () => {
const output = await outputPromise
expect(output).toSatisfy(isAsyncIteratorObject)

const urlProvider = vi.mocked(ReconnectableWebSocket).mock.calls[0]![0] as any

const url1 = await urlProvider()
expect(durableIteratorHandler).toHaveBeenCalledTimes(1)
expect(await urlProvider()).toEqual(url1)
expect(durableIteratorHandler).toHaveBeenCalledTimes(1) // not expired yet

await sleep(1000)
const url2 = await urlProvider()
expect(url1).toEqual(url2)
expect(durableIteratorHandler).toHaveBeenCalledTimes(1) // no refresh happened

expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledTimes(1)
expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledWith(
parseDurableIteratorToken(new URL(url1).searchParams.get(DURABLE_ITERATOR_TOKEN_PARAM)!),
expect.objectContaining({ path: ['durableIterator'] }),
)
expect(vi.getTimerCount()).toBe(0) // refresh token is disabled

await output.return() // cleanup
})

it('if refresh token is invalid', { timeout: 10000 }, async () => {
it('if refresh token is invalid', async () => {
refreshTokenBeforeExpireInSeconds.mockImplementation(() => 9)
durableIteratorHandler.mockImplementationOnce(
() => new DurableIterator<any, any>('some-room', {
Expand Down Expand Up @@ -286,20 +311,28 @@ describe('durableIteratorLinkPlugin', async () => {
const output = await outputPromise
expect(output).toSatisfy(isAsyncIteratorObject)

expect(vi.getTimerCount()).toBe(1) // refresh token is enabled

const urlProvider = vi.mocked(ReconnectableWebSocket).mock.calls[0]![0] as any

const url = await urlProvider()
expect(durableIteratorHandler).toHaveBeenCalledTimes(1)

durableIteratorHandler.mockResolvedValueOnce('invalid-token' as any)
await sleep(1000) // wait first retry trigger
vi.advanceTimersByTime(1000) // wait first retry trigger
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh promise
await expect(urlProvider()).resolves.toBe(url) // not change url because new token is invalid
expect(durableIteratorHandler).toHaveBeenCalledTimes(2)
expect(vi.getTimerCount()).toBe(1) // timer created by retry helper

durableIteratorHandler.mockResolvedValueOnce({} as any)
await sleep(2000) // wait next retry
vi.advanceTimersByTime(2000) // wait next retry
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh promise
await expect(urlProvider()).resolves.toBe(url) // not change url because new token is invalid
expect(durableIteratorHandler).toHaveBeenCalledTimes(3)
expect(vi.getTimerCount()).toBe(1) // timer created by retry helper

// only called once, because it still retrying after invalid token
expect(refreshTokenBeforeExpireInSeconds).toHaveBeenCalledTimes(1)
Expand All @@ -315,10 +348,13 @@ describe('durableIteratorLinkPlugin', async () => {
await sleep(2000)
return {} as any
})
await sleep(2000) // wait next retry
vi.advanceTimersByTime(2000) // wait next retry
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh trigger
await output.return() // cleanup

await sleep(2000)
vi.advanceTimersByTime(2000) // wait handler throw
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh reject
expect(unhandledRejectionHandler).toHaveBeenCalledTimes(1)
expect(unhandledRejectionHandler.mock.calls[0]![0]).toEqual(
Comment on lines +351 to 359
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Ensure the unhandledRejection listener is removed.

After the final assertions, call process.off('unhandledRejection', unhandledRejectionHandler) to prevent leaking the listener to subsequent tests.

🤖 Prompt for AI Agents
In packages/durable-iterator/src/client/plugin.test.ts around lines 351 to 359,
the test adds an unhandledRejection listener but never removes it, leaking the
listener into subsequent tests; after the final assertions add a call to remove
the listener by invoking process.off('unhandledRejection',
unhandledRejectionHandler) (or process.removeListener) to unregister the handler
and prevent test leakage.

new DurableIteratorError(`Expected valid token for procedure durableIterator`),
Expand Down Expand Up @@ -356,14 +392,19 @@ describe('durableIteratorLinkPlugin', async () => {
const output = await outputPromise
expect(output).toSatisfy(isAsyncIteratorObject)

expect(vi.getTimerCount()).toBe(1) // refresh token is enabled

durableIteratorHandler.mockResolvedValueOnce(
new DurableIterator<any, any>('a-different-channel', { signingKey: 'signing-key' }).rpc('getUser', 'sendMessage') as any,
)

await sleep(1000)
vi.advanceTimersByTime(1000)
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh promise
const ws = vi.mocked(ReconnectableWebSocket).mock.results[0]!.value
expect(ws.reconnect).toHaveBeenCalledTimes(1)
expect(await (ReconnectableWebSocket as any).mock.calls[0]![0]()).toContain('a-different-channel')
expect(vi.getTimerCount()).toBe(1) // new refresh token timer created

await output.return() // cleanup
})
Expand Down Expand Up @@ -400,6 +441,8 @@ describe('durableIteratorLinkPlugin', async () => {
const output = await outputPromise
expect(output).toSatisfy(isAsyncIteratorObject)

expect(vi.getTimerCount()).toBe(1) // refresh token is enabled

durableIteratorHandler.mockImplementationOnce(
() => new DurableIterator<any, any>('some-room', {
tags: ['a-different-tag'],
Expand All @@ -408,10 +451,13 @@ describe('durableIteratorLinkPlugin', async () => {
}) as any,
)

await sleep(1000)
vi.advanceTimersByTime(1000)
expect(vi.getTimerCount()).toBe(0) // refresh token executed
await new Promise(resolve => realSetTimeout(resolve, 10)) // wait for token refresh promise
const ws = vi.mocked(ReconnectableWebSocket).mock.results[0]!.value
expect(ws.reconnect).toHaveBeenCalledTimes(1)
expect(await (ReconnectableWebSocket as any).mock.calls[0]![0]()).toContain('a-different-tag')
expect(vi.getTimerCount()).toBe(1) // new refresh token timer created

await output.return() // cleanup
})
Expand Down
76 changes: 45 additions & 31 deletions packages/durable-iterator/src/client/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ export interface DurableIteratorLinkPluginOptions<T extends ClientContext> exten
* @default NaN (disabled)
*/
refreshTokenBeforeExpireInSeconds?: Value<Promisable<number>, [tokenPayload: DurableIteratorTokenPayload, options: StandardLinkInterceptorOptions<T>]>

/**
* Minimum delay between token refresh attempts.
*
* @default 2 (seconds)
*/
refreshTokenDelayInSeconds?: Value<Promisable<number>, [tokenPayload: DurableIteratorTokenPayload, options: StandardLinkInterceptorOptions<T>]>
}

/**
Expand All @@ -67,12 +74,14 @@ export class DurableIteratorLinkPlugin<T extends ClientContext> implements Stand
private readonly url: DurableIteratorLinkPluginOptions<T>['url']
private readonly createId: Exclude<DurableIteratorLinkPluginOptions<T>['createId'], undefined>
private readonly refreshTokenBeforeExpireInSeconds: Exclude<DurableIteratorLinkPluginOptions<T>['refreshTokenBeforeExpireInSeconds'], undefined>
private readonly refreshTokenDelayInSeconds: Exclude<DurableIteratorLinkPluginOptions<T>['refreshTokenDelayInSeconds'], undefined>
private readonly linkOptions: Omit<RPCLinkOptions<object>, 'websocket'>

constructor({ url, refreshTokenBeforeExpireInSeconds, ...options }: DurableIteratorLinkPluginOptions<T>) {
constructor({ url, refreshTokenBeforeExpireInSeconds, refreshTokenDelayInSeconds, ...options }: DurableIteratorLinkPluginOptions<T>) {
this.url = url
this.createId = fallback(options.createId, () => crypto.randomUUID())
this.refreshTokenBeforeExpireInSeconds = fallback(refreshTokenBeforeExpireInSeconds, Number.NaN)
this.refreshTokenDelayInSeconds = fallback(refreshTokenDelayInSeconds, 2)
this.linkOptions = options
}

Expand Down Expand Up @@ -127,56 +136,61 @@ export class DurableIteratorLinkPlugin<T extends ClientContext> implements Stand
let refreshTokenBeforeExpireTimeoutId: ReturnType<typeof setTimeout> | undefined
const refreshTokenBeforeExpire = async () => {
const beforeSeconds = await value(this.refreshTokenBeforeExpireInSeconds, tokenAndPayload.payload, options)
const delayMilliseconds = await value(this.refreshTokenDelayInSeconds, tokenAndPayload.payload, options) * 1000

// stop refreshing if already finished
if (isFinished || !Number.isFinite(beforeSeconds)) {
return
}

const nowInSeconds = Math.floor(Date.now() / 1000)

refreshTokenBeforeExpireTimeoutId = setTimeout(async () => {
refreshTokenBeforeExpireTimeoutId = setTimeout(
async () => {
// retry until success or finished
const newTokenAndPayload = await retry({ times: Number.POSITIVE_INFINITY, delay: 2000 }, async (exit) => {
try {
const output = await next()
return this.validateToken(output, options.path)
}
catch (err) {
if (isFinished) {
exit(err)
const newTokenAndPayload = await retry({ times: Number.POSITIVE_INFINITY, delay: delayMilliseconds }, async (exit) => {
Comment thread
dinwwwh marked this conversation as resolved.
try {
const output = await next()
return this.validateToken(output, options.path)
}
catch (err) {
if (isFinished) {
exit(err)
}

throw err
}
})
throw err
}
})

const canProactivelyUpdateToken
= newTokenAndPayload.payload.chn === tokenAndPayload.payload.chn
&& stringifyJSON(newTokenAndPayload.payload.tags) === stringifyJSON(tokenAndPayload.payload.tags)
const canProactivelyUpdateToken
= newTokenAndPayload.payload.chn === tokenAndPayload.payload.chn
&& stringifyJSON(newTokenAndPayload.payload.tags) === stringifyJSON(tokenAndPayload.payload.tags)

tokenAndPayload = newTokenAndPayload
await refreshTokenBeforeExpire() // recursively call
tokenAndPayload = newTokenAndPayload
await refreshTokenBeforeExpire() // recursively call

/**
* The next refresh cycle doesn't depend on the logic below,
* so we place it last to avoid interfering with recursion.
*/
if (canProactivelyUpdateToken) {
/**
* The next refresh cycle doesn't depend on the logic below,
* so we place it last to avoid interfering with recursion.
*/
if (canProactivelyUpdateToken) {
/**
* Proactively update the token before expiration
* to avoid reconnecting when the old token expires.
*/
await durableClient.updateToken({ token: tokenAndPayload.token })
}
else {
await durableClient.updateToken({ token: tokenAndPayload.token })
}
else {
/**
* Proactive update requires the same channel and tags.
* If they differ, we must reconnect instead to make new token effective.
*/
websocket.reconnect()
}
}, (tokenAndPayload.payload.exp - nowInSeconds - beforeSeconds) * 1000)
websocket.reconnect()
}
},
Math.max(
refreshTokenBeforeExpireTimeoutId === undefined ? 0 : delayMilliseconds,
((tokenAndPayload.payload.exp - beforeSeconds) * 1000) - Date.now(),
),
)
}
refreshTokenBeforeExpire()

Expand Down
2 changes: 0 additions & 2 deletions packages/durable-iterator/src/iterator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ export interface DurableIteratorOptions<
* The methods that are allowed to be called remotely.
*
* @warning Please use .rpc method to set this field in case ts complains about value you pass
*
* @default []
*/
rpc?: readonly RPC[]
}
Expand Down
Loading