diff --git a/packages/zudoku/src/lib/authentication/hook.ts b/packages/zudoku/src/lib/authentication/hook.ts index 6b3280140..ed11b52b9 100644 --- a/packages/zudoku/src/lib/authentication/hook.ts +++ b/packages/zudoku/src/lib/authentication/hook.ts @@ -13,18 +13,31 @@ export type UseAuthReturn = ReturnType; */ export const useRefreshUserProfile = ({ refetchOnWindowFocus, + refetchOnMount, }: { refetchOnWindowFocus?: boolean | "always"; + refetchOnMount?: boolean | "always"; } = {}) => { const { authentication } = useZudoku(); + const profile = useAuthState((s) => s.profile); + const profileFetchedAt = useAuthState((s) => s.profileFetchedAt); const isAuthEnabled = typeof authentication !== "undefined"; return useQuery({ refetchOnWindowFocus, + refetchOnMount, + staleTime: 1000 * 60 * 5, + gcTime: 1000 * 60 * 10, queryKey: ["refresh-user-profile"], enabled: isAuthEnabled && typeof authentication?.refreshUserProfile === "function", - queryFn: () => authentication?.refreshUserProfile?.(), + queryFn: async () => { + const result = await authentication?.refreshUserProfile?.(); + useAuthState.setState({ profileFetchedAt: Date.now() }); + return result; + }, + initialData: profile ? true : undefined, + initialDataUpdatedAt: profileFetchedAt ?? undefined, }); }; @@ -34,8 +47,11 @@ export const useVerifiedEmail = () => { const navigate = useNavigate(); const isAuthEnabled = typeof authentication !== "undefined"; + const isUnverified = authState.profile?.emailVerified === false; + const { refetch: refreshUserProfile } = useRefreshUserProfile({ - refetchOnWindowFocus: "always", + refetchOnWindowFocus: isUnverified ? "always" : true, + refetchOnMount: isUnverified ? "always" : undefined, }); return { diff --git a/packages/zudoku/src/lib/authentication/providers/clerk.tsx b/packages/zudoku/src/lib/authentication/providers/clerk.tsx index 91db91fb3..fd19ba368 100644 --- a/packages/zudoku/src/lib/authentication/providers/clerk.tsx +++ b/packages/zudoku/src/lib/authentication/providers/clerk.tsx @@ -127,6 +127,7 @@ const clerkAuth: AuthenticationProviderInitializer< isAuthenticated: true, isPending: false, profile, + profileFetchedAt: Date.now(), providerData: { type: "clerk", user: clerk.session?.user, diff --git a/packages/zudoku/src/lib/authentication/providers/firebase.tsx b/packages/zudoku/src/lib/authentication/providers/firebase.tsx index 15d8cd21b..7b210dd51 100644 --- a/packages/zudoku/src/lib/authentication/providers/firebase.tsx +++ b/packages/zudoku/src/lib/authentication/providers/firebase.tsx @@ -365,6 +365,7 @@ class FirebaseAuthenticationProvider isAuthenticated: false, isPending: false, profile: undefined, + profileFetchedAt: null, providerData: undefined, }); diff --git a/packages/zudoku/src/lib/authentication/providers/openid.test.ts b/packages/zudoku/src/lib/authentication/providers/openid.test.ts index e63fcc53b..b9b7558b5 100644 --- a/packages/zudoku/src/lib/authentication/providers/openid.test.ts +++ b/packages/zudoku/src/lib/authentication/providers/openid.test.ts @@ -471,6 +471,67 @@ describe("OpenIDAuthenticationProvider emailVerified", () => { }); }); + describe("discovery caching", () => { + const setupAuthenticated = () => { + useAuthState.setState({ + isAuthenticated: true, + isPending: false, + profile: { + sub: "user-1", + email: "user@example.com", + emailVerified: false, + name: "Test", + pictureUrl: undefined, + }, + providerData: { + type: "openid", + accessToken: FAKE_ACCESS_TOKEN, + expiresOn: new Date(Date.now() + 3600_000), + tokenType: "bearer", + claims: undefined, + } satisfies OpenIdProviderData, + }); + + vi.mocked(oauth.userInfoRequest).mockImplementation(() => + Promise.resolve( + Response.json({ sub: "user-1", email: "user@example.com" }), + ), + ); + }; + + test("retries discovery after a failed request", async () => { + vi.mocked(oauth.discoveryRequest) + .mockReset() + .mockRejectedValueOnce(new Error("network down")) + .mockImplementation(() => Promise.resolve(new Response())); + + const provider = createProvider(); + setupAuthenticated(); + + await expect(provider.refreshUserProfile()).rejects.toThrow( + "network down", + ); + await expect(provider.refreshUserProfile()).resolves.toBe(true); + + expect(oauth.discoveryRequest).toHaveBeenCalledTimes(2); + }); + + test("deduplicates concurrent discovery requests", async () => { + vi.mocked(oauth.discoveryRequest).mockClear(); + + const provider = createProvider(); + setupAuthenticated(); + + await Promise.all([ + provider.refreshUserProfile(), + provider.refreshUserProfile(), + provider.refreshUserProfile(), + ]); + + expect(oauth.discoveryRequest).toHaveBeenCalledTimes(1); + }); + }); + test("self heals providerData when providerData.type is undefined", async () => { const provider = createProvider(); diff --git a/packages/zudoku/src/lib/authentication/providers/openid.tsx b/packages/zudoku/src/lib/authentication/providers/openid.tsx index 5877189d2..2a6964518 100644 --- a/packages/zudoku/src/lib/authentication/providers/openid.tsx +++ b/packages/zudoku/src/lib/authentication/providers/openid.tsx @@ -46,7 +46,7 @@ export class OpenIDAuthenticationProvider { protected client: oauth.Client; protected issuer: string; - protected authorizationServer: oauth.AuthorizationServer | undefined; + protected authorizationServer: Promise | undefined; protected callbackUrlPath: string; @@ -105,14 +105,16 @@ export class OpenIDAuthenticationProvider } protected async getAuthServer() { - if (!this.authorizationServer) { - const issuerUrl = new URL(this.issuer); - const response = await oauth.discoveryRequest(issuerUrl); - this.authorizationServer = await oauth.processDiscoveryResponse( - issuerUrl, - response, - ); - } + this.authorizationServer ??= (async () => { + try { + const issuerUrl = new URL(this.issuer); + const response = await oauth.discoveryRequest(issuerUrl); + return await oauth.processDiscoveryResponse(issuerUrl, response); + } catch (err) { + this.authorizationServer = undefined; + throw err; + } + })(); return this.authorizationServer; } @@ -239,6 +241,7 @@ export class OpenIDAuthenticationProvider isAuthenticated: true, isPending: false, profile, + profileFetchedAt: Date.now(), }); return true; @@ -438,6 +441,7 @@ export class OpenIDAuthenticationProvider isAuthenticated: false, isPending: false, profile: null, + profileFetchedAt: null, providerData: null, }); return; @@ -544,6 +548,7 @@ export class OpenIDAuthenticationProvider isAuthenticated: true, isPending: false, profile, + profileFetchedAt: Date.now(), }); await this.refreshUserProfile(); diff --git a/packages/zudoku/src/lib/authentication/state.ts b/packages/zudoku/src/lib/authentication/state.ts index d3a619f07..9279db4ea 100644 --- a/packages/zudoku/src/lib/authentication/state.ts +++ b/packages/zudoku/src/lib/authentication/state.ts @@ -25,6 +25,7 @@ export interface AuthState { isAuthenticated: boolean; isPending: boolean; profile: UserProfile | null; + profileFetchedAt: number | null; providerData: ProviderData | null; setAuthenticationPending: () => void; setLoggedOut: () => void; @@ -40,12 +41,14 @@ export const authState = create()( isAuthenticated: false, isPending: true, profile: null, + profileFetchedAt: null, providerData: null, setAuthenticationPending: () => set(() => ({ isAuthenticated: false, isPending: false, profile: null, + profileFetchedAt: null, providerData: null, })), setLoggedOut: () => @@ -53,6 +56,7 @@ export const authState = create()( isAuthenticated: false, isPending: false, profile: null, + profileFetchedAt: null, providerData: null, })), setLoggedIn: ({ profile, providerData }) => @@ -60,6 +64,7 @@ export const authState = create()( isAuthenticated: true, isPending: false, profile, + profileFetchedAt: Date.now(), providerData, })), }), diff --git a/packages/zudoku/src/lib/core/RouteGuard.test.tsx b/packages/zudoku/src/lib/core/RouteGuard.test.tsx index c8f861f85..1625c0e63 100644 --- a/packages/zudoku/src/lib/core/RouteGuard.test.tsx +++ b/packages/zudoku/src/lib/core/RouteGuard.test.tsx @@ -70,6 +70,7 @@ const createWrapper = ({ isPending: false, isAuthEnabled: false, profile: null, + profileFetchedAt: null, providerData: null, setAuthenticationPending: vi.fn(), setLoggedOut: vi.fn(),