Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 58 additions & 28 deletions packages/fiber/src/core/hooks/useLoader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,57 @@ const memoizedLoaders = new WeakMap<ConstructorRepresentation<LoaderLike>, Loade
const isConstructor = (value: unknown): value is ConstructorRepresentation<LoaderLike> =>
typeof value === 'function' && value?.prototype?.constructor === value

//* Loader Retrieval Utility ==============================

/**
* Gets or creates a memoized loader instance from a loader constructor or returns the loader if it's already an instance.
* This allows external code to access loader methods like abort().
*/
function getLoader<L extends LoaderLike | ConstructorRepresentation<LoaderLike>>(
Proto: L,
): L extends ConstructorRepresentation<infer T> ? T : L {
// Construct and cache loader if constructor was passed
if (isConstructor(Proto)) {
let loader = memoizedLoaders.get(Proto)
if (!loader) {
loader = new Proto()
memoizedLoaders.set(Proto, loader)
}
return loader as L extends ConstructorRepresentation<infer T> ? T : L
}

// Return the loader instance as-is
return Proto as L extends ConstructorRepresentation<infer T> ? T : L
}

function loadingFn<L extends LoaderLike | ConstructorRepresentation<LoaderLike>>(
extensions?: Extensions<L>,
onProgress?: (event: ProgressEvent<EventTarget>) => void,
) {
return function (Proto: L, ...input: string[]) {
let loader: LoaderLike = Proto as any

// Construct and cache loader if constructor was passed
if (isConstructor(Proto)) {
loader = memoizedLoaders.get(Proto)!
if (!loader) {
loader = new Proto()
memoizedLoaders.set(Proto, loader)
}
}
return function (Proto: L, input: string) {
const loader = getLoader(Proto)

// Apply loader extensions
if (extensions) extensions(loader as any)

// Go through the urls and load them
return Promise.all(
input.map(
(input) =>
new Promise<LoaderResult<L>>((res, reject) =>
loader.load(
input,
(data: any) => {
if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene))
res(data)
},
onProgress,
(error: unknown) => reject(new Error(`Could not load ${input}: ${(error as ErrorEvent)?.message}`)),
),
),
// Prefer loadAsync if available (supports abort, cleaner Promise API)
if ('loadAsync' in loader && typeof loader.loadAsync === 'function') {
return loader.loadAsync(input, onProgress).then((data: any) => {
if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene))
return data
}) as Promise<LoaderResult<L>>
}

// Fall back to callback-based load
return new Promise<LoaderResult<L>>((res, reject) =>
loader.load(
input,
(data: any) => {
if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene))
res(data)
},
onProgress,
(error: unknown) => reject(new Error(`Could not load ${input}: ${(error as ErrorEvent)?.message}`)),
),
)
}
Expand All @@ -63,7 +80,13 @@ export function useLoader<I extends InputLike, L extends LoaderLike | Constructo
) {
// Use suspense to load async assets
const keys = (Array.isArray(input) ? input : [input]) as string[]
const results = suspend(loadingFn(extensions, onProgress), [loader, ...keys], { equal: is.equ })

// Create the loading function once to ensure consistent function reference across suspend calls
const fn = loadingFn(extensions, onProgress)

// Call suspend individually for each key to match preload cache structure
const results = keys.map((key) => suspend(fn, [loader, key], { equal: is.equ }))

// Return the object(s)
return (Array.isArray(input) ? results : results[0]) as I extends any[] ? LoaderResult<L>[] : LoaderResult<L>
}
Expand All @@ -75,10 +98,11 @@ useLoader.preload = function <I extends InputLike, L extends LoaderLike | Constr
loader: L,
input: I,
extensions?: Extensions<L>,
onProgress?: (event: ProgressEvent<EventTarget>) => void,
): void {
const keys = (Array.isArray(input) ? input : [input]) as string[]
// Preload each key individually so cache keys match useLoader calls
keys.forEach((key) => preload(loadingFn(extensions), [loader, key]))
keys.forEach((key) => preload(loadingFn(extensions, onProgress), [loader, key]))
}

/**
Expand All @@ -92,3 +116,9 @@ useLoader.clear = function <I extends InputLike, L extends LoaderLike | Construc
// Clear each key individually to match how they were cached
keys.forEach((key) => clear([loader, key]))
}

/**
* Gets the memoized loader instance, allowing access to loader methods like abort().
* For constructor-based loaders, returns the cached instance. For instance loaders, returns the instance itself.
*/
useLoader.loader = getLoader
4 changes: 2 additions & 2 deletions packages/fiber/src/core/renderer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { R3F_BUILD_LEGACY, R3F_BUILD_WEBGPU, WebGLRenderer, WebGPURenderer, Insp

import type { Object3D } from '#three'
import type { JSX, ReactNode, RefObject } from 'react'
import { useMemo, useState } from 'react'
import { useCallback, useMemo, useState } from 'react'
import { ConcurrentRoot } from 'react-reconciler/constants'
import { createWithEqualityFn } from 'zustand/traditional'

Expand Down Expand Up @@ -651,7 +651,7 @@ interface PortalWrapperProps {

//* Portal Wrapper - Handles Ref Resolution ==============================
function PortalWrapper({ children, container, state }: PortalWrapperProps): JSX.Element {
const isRef = (obj: any): obj is RefObject<Object3D> => obj && 'current' in obj
const isRef = useCallback((obj: any): obj is RefObject<Object3D> => obj && 'current' in obj, [])
const [resolvedContainer, setResolvedContainer] = useState<Object3D | null>(() => {
if (isRef(container)) return container.current ?? null
return container as Object3D
Expand Down
206 changes: 8 additions & 198 deletions packages/fiber/tests/hooks.test.tsx
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
import * as React from 'react'
import { act } from 'react'
import * as THREE from 'three'
import { createCanvas } from '../../test-renderer/src/createTestCanvas'

import {
createRoot,
advance,
useLoader,
useThree,
useGraph,
useFrame,
ObjectMap,
useInstanceHandle,
Instance,
extend,
} from '../src'
import { createRoot, useThree, useGraph, ObjectMap, useInstanceHandle, Instance, extend } from '../src'

extend(THREE as any)
const root = createRoot(document.createElement('canvas'))

describe('hooks', () => {
let canvas: HTMLCanvasElement = null!
let root: ReturnType<typeof createRoot> = null!

beforeEach(() => {
canvas = createCanvas()
root = createRoot(document.createElement('canvas'))
})

afterEach(async () => {
await act(async () => root.unmount())
})

it('can handle useThree hook', async () => {
Expand Down Expand Up @@ -61,188 +52,7 @@ describe('hooks', () => {
expect(result.size).toEqual({ height: 0, width: 0, top: 0, left: 0 })
})

it('can handle useFrame hook', async () => {
const frameCalls: number[] = []

const Component = () => {
const ref = React.useRef<THREE.Mesh>(null!)
useFrame((_, delta) => {
frameCalls.push(delta)
ref.current.position.x = 1
})

return (
<mesh ref={ref}>
<boxGeometry args={[2, 2]} />
<meshBasicMaterial />
</mesh>
)
}

const store = await act(async () => (await root.configure({ frameloop: 'never' })).render(<Component />))
const { scene } = store.getState()

advance(Date.now())
expect(scene.children[0].position.x).toEqual(1)
expect(frameCalls.length).toBeGreaterThan(0)
})

it('can handle useLoader hook', async () => {
const MockMesh = new THREE.Mesh()
MockMesh.name = 'Scene'

interface GLTF {
scene: THREE.Object3D
}
class GLTFLoader extends THREE.Loader<GLTF, string> {
load(url: string, onLoad: (gltf: GLTF) => void): void {
onLoad({ scene: MockMesh })
}
}

let gltf!: GLTF & ObjectMap
const Component = () => {
gltf = useLoader(GLTFLoader, '/suzanne.glb')
return <primitive object={gltf.scene} />
}

const store = await act(async () => root.render(<Component />))
const { scene } = store.getState()

expect(scene.children[0]).toBe(MockMesh)
expect(gltf.scene).toBe(MockMesh)
expect(gltf.nodes.Scene).toBe(MockMesh)
})

it('can handle useLoader hook with an array of strings', async () => {
const MockMesh = new THREE.Mesh()

const MockGroup = new THREE.Group()
const mat1 = new THREE.MeshBasicMaterial()
mat1.name = 'Mat 1'
const mesh1 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat1)
mesh1.name = 'Mesh 1'
const mat2 = new THREE.MeshBasicMaterial()
mat2.name = 'Mat 2'
const mesh2 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat2)
mesh2.name = 'Mesh 2'
MockGroup.add(mesh1, mesh2)

class TestLoader extends THREE.Loader {
load = jest
.fn()
.mockImplementationOnce((_url, onLoad) => {
onLoad(MockMesh)
})
.mockImplementationOnce((_url, onLoad) => {
onLoad(MockGroup)
})
}

const extensions = jest.fn()

const Component = () => {
const [mockMesh, mockScene] = useLoader(TestLoader, ['/suzanne.glb', '/myModels.glb'], extensions)

return (
<>
<primitive object={mockMesh as THREE.Mesh} />
<primitive object={mockScene as THREE.Scene} />
</>
)
}

const store = await act(async () => root.render(<Component />))
const { scene } = store.getState()

expect(scene.children[0]).toBe(MockMesh)
expect(scene.children[1]).toBe(MockGroup)
expect(extensions).toHaveBeenCalledTimes(1)
})

it('can handle useLoader with an existing loader instance', async () => {
class Loader extends THREE.Loader<null, string> {
load(_url: string, onLoad: (result: null) => void): void {
onLoad(null)
}
}

const loader = new Loader()
let proto!: Loader

function Test(): null {
return useLoader(loader, '', (loader) => (proto = loader))
}
await act(async () => root.render(<Test />))

expect(proto).toBe(loader)
})

it('can handle useLoader with a loader extension', async () => {
class Loader extends THREE.Loader<null, string> {
load(_url: string, onLoad: (result: null) => void): void {
onLoad(null)
}
}

let proto!: Loader

function Test(): null {
return useLoader(Loader, '', (loader) => (proto = loader))
}
await act(async () => root.render(<Test />))

expect(proto).toBeInstanceOf(Loader)
})

it('useLoader.preload with array caches each URL individually', async () => {
const loadCalls: string[] = []

class TestLoader extends THREE.Loader<string, string> {
load(url: string, onLoad: (result: string) => void): void {
loadCalls.push(url)
onLoad(`loaded:${url}`)
}
}

const URL_A = '/model-a.glb'
const URL_B = '/model-b.glb'

// Preload with an array - this should cache each URL individually
useLoader.preload(TestLoader, [URL_A, URL_B])

// Wait for preload promises to resolve
await new Promise((resolve) => setTimeout(resolve, 10))

// Clear load tracking to isolate the useLoader calls
const preloadCallCount = loadCalls.length
expect(preloadCallCount).toBe(2) // Both URLs should have been loaded

// Now use useLoader with individual URLs - should hit cache, not reload
let resultA: string | undefined
let resultB: string | undefined

const ComponentA = () => {
resultA = useLoader(TestLoader, URL_A)
return null
}

const ComponentB = () => {
resultB = useLoader(TestLoader, URL_B)
return null
}

await act(async () => root.render(<ComponentA />))
await act(async () => root.render(<ComponentB />))

// The loader should NOT have been called again - cache should have been hit
expect(loadCalls.length).toBe(2) // Still just the 2 preload calls
expect(resultA).toBe(`loaded:${URL_A}`)
expect(resultB).toBe(`loaded:${URL_B}`)

// Clean up cache for other tests
useLoader.clear(TestLoader, [URL_A, URL_B])
})
// Note: useFrame has its own dedicated test file (useFrame.test.tsx)

it('can handle useGraph hook', async () => {
const group = new THREE.Group()
Expand Down
6 changes: 3 additions & 3 deletions packages/fiber/tests/useFrame.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,9 @@ describe('useFrame hook', () => {
await new Promise((resolve) => setTimeout(resolve, 100))
})

// Verify error was set in store
const state = store.getState()
expect(state.error).toBe(testError)
// Verify error was set in store (only extract the error property to avoid circular references)
const error = store.getState().error
expect(error).toBe(testError)
})

//* Legacy Priority Tests ==============================
Expand Down
Loading