diff --git a/packages/fiber/src/core/hooks/useLoader.tsx b/packages/fiber/src/core/hooks/useLoader.tsx index 165acfec30..e029e98753 100644 --- a/packages/fiber/src/core/hooks/useLoader.tsx +++ b/packages/fiber/src/core/hooks/useLoader.tsx @@ -10,40 +10,57 @@ const memoizedLoaders = new WeakMap, Loade const isConstructor = (value: unknown): value is ConstructorRepresentation => typeof value === 'function' && value?.prototype?.constructor === value +//* Loader Retrieval Utility ============================== + +/** + * Gets or creates a memoized loader instance from a loader constructor or returns the loader if it's already an instance. + * This allows external code to access loader methods like abort(). + */ +function getLoader>( + Proto: L, +): L extends ConstructorRepresentation ? T : L { + // Construct and cache loader if constructor was passed + if (isConstructor(Proto)) { + let loader = memoizedLoaders.get(Proto) + if (!loader) { + loader = new Proto() + memoizedLoaders.set(Proto, loader) + } + return loader as L extends ConstructorRepresentation ? T : L + } + + // Return the loader instance as-is + return Proto as L extends ConstructorRepresentation ? T : L +} + function loadingFn>( extensions?: Extensions, onProgress?: (event: ProgressEvent) => void, ) { - return function (Proto: L, ...input: string[]) { - let loader: LoaderLike = Proto as any - - // Construct and cache loader if constructor was passed - if (isConstructor(Proto)) { - loader = memoizedLoaders.get(Proto)! - if (!loader) { - loader = new Proto() - memoizedLoaders.set(Proto, loader) - } - } + return function (Proto: L, input: string) { + const loader = getLoader(Proto) // Apply loader extensions if (extensions) extensions(loader as any) - // Go through the urls and load them - return Promise.all( - input.map( - (input) => - new Promise>((res, reject) => - loader.load( - input, - (data: any) => { - if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene)) - res(data) - }, - onProgress, - (error: unknown) => reject(new Error(`Could not load ${input}: ${(error as ErrorEvent)?.message}`)), - ), - ), + // Prefer loadAsync if available (supports abort, cleaner Promise API) + if ('loadAsync' in loader && typeof loader.loadAsync === 'function') { + return loader.loadAsync(input, onProgress).then((data: any) => { + if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene)) + return data + }) as Promise> + } + + // Fall back to callback-based load + return new Promise>((res, reject) => + loader.load( + input, + (data: any) => { + if (isObject3D(data?.scene)) Object.assign(data, buildGraph(data.scene)) + res(data) + }, + onProgress, + (error: unknown) => reject(new Error(`Could not load ${input}: ${(error as ErrorEvent)?.message}`)), ), ) } @@ -63,7 +80,13 @@ export function useLoader suspend(fn, [loader, key], { equal: is.equ })) + // Return the object(s) return (Array.isArray(input) ? results : results[0]) as I extends any[] ? LoaderResult[] : LoaderResult } @@ -75,10 +98,11 @@ useLoader.preload = function , + onProgress?: (event: ProgressEvent) => void, ): void { const keys = (Array.isArray(input) ? input : [input]) as string[] // Preload each key individually so cache keys match useLoader calls - keys.forEach((key) => preload(loadingFn(extensions), [loader, key])) + keys.forEach((key) => preload(loadingFn(extensions, onProgress), [loader, key])) } /** @@ -92,3 +116,9 @@ useLoader.clear = function clear([loader, key])) } + +/** + * Gets the memoized loader instance, allowing access to loader methods like abort(). + * For constructor-based loaders, returns the cached instance. For instance loaders, returns the instance itself. + */ +useLoader.loader = getLoader diff --git a/packages/fiber/src/core/renderer.tsx b/packages/fiber/src/core/renderer.tsx index a975f1d527..2ebb747ebc 100644 --- a/packages/fiber/src/core/renderer.tsx +++ b/packages/fiber/src/core/renderer.tsx @@ -3,7 +3,7 @@ import { R3F_BUILD_LEGACY, R3F_BUILD_WEBGPU, WebGLRenderer, WebGPURenderer, Insp import type { Object3D } from '#three' import type { JSX, ReactNode, RefObject } from 'react' -import { useMemo, useState } from 'react' +import { useCallback, useMemo, useState } from 'react' import { ConcurrentRoot } from 'react-reconciler/constants' import { createWithEqualityFn } from 'zustand/traditional' @@ -651,7 +651,7 @@ interface PortalWrapperProps { //* Portal Wrapper - Handles Ref Resolution ============================== function PortalWrapper({ children, container, state }: PortalWrapperProps): JSX.Element { - const isRef = (obj: any): obj is RefObject => obj && 'current' in obj + const isRef = useCallback((obj: any): obj is RefObject => obj && 'current' in obj, []) const [resolvedContainer, setResolvedContainer] = useState(() => { if (isRef(container)) return container.current ?? null return container as Object3D diff --git a/packages/fiber/tests/hooks.test.tsx b/packages/fiber/tests/hooks.test.tsx index 7d1b241f9f..0b07f889ab 100644 --- a/packages/fiber/tests/hooks.test.tsx +++ b/packages/fiber/tests/hooks.test.tsx @@ -1,29 +1,20 @@ import * as React from 'react' import { act } from 'react' import * as THREE from 'three' -import { createCanvas } from '../../test-renderer/src/createTestCanvas' -import { - createRoot, - advance, - useLoader, - useThree, - useGraph, - useFrame, - ObjectMap, - useInstanceHandle, - Instance, - extend, -} from '../src' +import { createRoot, useThree, useGraph, ObjectMap, useInstanceHandle, Instance, extend } from '../src' extend(THREE as any) -const root = createRoot(document.createElement('canvas')) describe('hooks', () => { - let canvas: HTMLCanvasElement = null! + let root: ReturnType = null! beforeEach(() => { - canvas = createCanvas() + root = createRoot(document.createElement('canvas')) + }) + + afterEach(async () => { + await act(async () => root.unmount()) }) it('can handle useThree hook', async () => { @@ -61,188 +52,7 @@ describe('hooks', () => { expect(result.size).toEqual({ height: 0, width: 0, top: 0, left: 0 }) }) - it('can handle useFrame hook', async () => { - const frameCalls: number[] = [] - - const Component = () => { - const ref = React.useRef(null!) - useFrame((_, delta) => { - frameCalls.push(delta) - ref.current.position.x = 1 - }) - - return ( - - - - - ) - } - - const store = await act(async () => (await root.configure({ frameloop: 'never' })).render()) - const { scene } = store.getState() - - advance(Date.now()) - expect(scene.children[0].position.x).toEqual(1) - expect(frameCalls.length).toBeGreaterThan(0) - }) - - it('can handle useLoader hook', async () => { - const MockMesh = new THREE.Mesh() - MockMesh.name = 'Scene' - - interface GLTF { - scene: THREE.Object3D - } - class GLTFLoader extends THREE.Loader { - load(url: string, onLoad: (gltf: GLTF) => void): void { - onLoad({ scene: MockMesh }) - } - } - - let gltf!: GLTF & ObjectMap - const Component = () => { - gltf = useLoader(GLTFLoader, '/suzanne.glb') - return - } - - const store = await act(async () => root.render()) - const { scene } = store.getState() - - expect(scene.children[0]).toBe(MockMesh) - expect(gltf.scene).toBe(MockMesh) - expect(gltf.nodes.Scene).toBe(MockMesh) - }) - - it('can handle useLoader hook with an array of strings', async () => { - const MockMesh = new THREE.Mesh() - - const MockGroup = new THREE.Group() - const mat1 = new THREE.MeshBasicMaterial() - mat1.name = 'Mat 1' - const mesh1 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat1) - mesh1.name = 'Mesh 1' - const mat2 = new THREE.MeshBasicMaterial() - mat2.name = 'Mat 2' - const mesh2 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat2) - mesh2.name = 'Mesh 2' - MockGroup.add(mesh1, mesh2) - - class TestLoader extends THREE.Loader { - load = jest - .fn() - .mockImplementationOnce((_url, onLoad) => { - onLoad(MockMesh) - }) - .mockImplementationOnce((_url, onLoad) => { - onLoad(MockGroup) - }) - } - - const extensions = jest.fn() - - const Component = () => { - const [mockMesh, mockScene] = useLoader(TestLoader, ['/suzanne.glb', '/myModels.glb'], extensions) - - return ( - <> - - - - ) - } - - const store = await act(async () => root.render()) - const { scene } = store.getState() - - expect(scene.children[0]).toBe(MockMesh) - expect(scene.children[1]).toBe(MockGroup) - expect(extensions).toHaveBeenCalledTimes(1) - }) - - it('can handle useLoader with an existing loader instance', async () => { - class Loader extends THREE.Loader { - load(_url: string, onLoad: (result: null) => void): void { - onLoad(null) - } - } - - const loader = new Loader() - let proto!: Loader - - function Test(): null { - return useLoader(loader, '', (loader) => (proto = loader)) - } - await act(async () => root.render()) - - expect(proto).toBe(loader) - }) - - it('can handle useLoader with a loader extension', async () => { - class Loader extends THREE.Loader { - load(_url: string, onLoad: (result: null) => void): void { - onLoad(null) - } - } - - let proto!: Loader - - function Test(): null { - return useLoader(Loader, '', (loader) => (proto = loader)) - } - await act(async () => root.render()) - - expect(proto).toBeInstanceOf(Loader) - }) - - it('useLoader.preload with array caches each URL individually', async () => { - const loadCalls: string[] = [] - - class TestLoader extends THREE.Loader { - load(url: string, onLoad: (result: string) => void): void { - loadCalls.push(url) - onLoad(`loaded:${url}`) - } - } - - const URL_A = '/model-a.glb' - const URL_B = '/model-b.glb' - - // Preload with an array - this should cache each URL individually - useLoader.preload(TestLoader, [URL_A, URL_B]) - - // Wait for preload promises to resolve - await new Promise((resolve) => setTimeout(resolve, 10)) - - // Clear load tracking to isolate the useLoader calls - const preloadCallCount = loadCalls.length - expect(preloadCallCount).toBe(2) // Both URLs should have been loaded - - // Now use useLoader with individual URLs - should hit cache, not reload - let resultA: string | undefined - let resultB: string | undefined - - const ComponentA = () => { - resultA = useLoader(TestLoader, URL_A) - return null - } - - const ComponentB = () => { - resultB = useLoader(TestLoader, URL_B) - return null - } - - await act(async () => root.render()) - await act(async () => root.render()) - - // The loader should NOT have been called again - cache should have been hit - expect(loadCalls.length).toBe(2) // Still just the 2 preload calls - expect(resultA).toBe(`loaded:${URL_A}`) - expect(resultB).toBe(`loaded:${URL_B}`) - - // Clean up cache for other tests - useLoader.clear(TestLoader, [URL_A, URL_B]) - }) + // Note: useFrame has its own dedicated test file (useFrame.test.tsx) it('can handle useGraph hook', async () => { const group = new THREE.Group() diff --git a/packages/fiber/tests/useFrame.test.tsx b/packages/fiber/tests/useFrame.test.tsx index 0ed5a16978..7971ad3580 100644 --- a/packages/fiber/tests/useFrame.test.tsx +++ b/packages/fiber/tests/useFrame.test.tsx @@ -820,9 +820,9 @@ describe('useFrame hook', () => { await new Promise((resolve) => setTimeout(resolve, 100)) }) - // Verify error was set in store - const state = store.getState() - expect(state.error).toBe(testError) + // Verify error was set in store (only extract the error property to avoid circular references) + const error = store.getState().error + expect(error).toBe(testError) }) //* Legacy Priority Tests ============================== diff --git a/packages/fiber/tests/useLoader.test.tsx b/packages/fiber/tests/useLoader.test.tsx new file mode 100644 index 0000000000..849d5cfcea --- /dev/null +++ b/packages/fiber/tests/useLoader.test.tsx @@ -0,0 +1,298 @@ +import * as React from 'react' +import { act } from 'react' +import * as THREE from 'three' + +import { createRoot, useLoader, ObjectMap, extend } from '../src' + +extend(THREE as any) + +describe('useLoader', () => { + let root: ReturnType = null! + + beforeEach(() => { + root = createRoot(document.createElement('canvas')) + }) + + afterEach(async () => { + await act(async () => root.unmount()) + }) + + it('can handle useLoader hook', async () => { + const MockMesh = new THREE.Mesh() + MockMesh.name = 'Scene' + + interface GLTF { + scene: THREE.Object3D + } + class GLTFLoader extends THREE.Loader { + load(url: string, onLoad: (gltf: GLTF) => void): void { + onLoad({ scene: MockMesh }) + } + } + + let gltf!: GLTF & ObjectMap + const Component = () => { + gltf = useLoader(GLTFLoader, '/suzanne.glb') + return + } + + const store = await act(async () => root.render()) + const { scene } = store.getState() + + expect(scene.children[0]).toBe(MockMesh) + expect(gltf.scene).toBe(MockMesh) + expect(gltf.nodes.Scene).toBe(MockMesh) + }) + + it('can handle useLoader hook with an array of strings', async () => { + // Use unique URLs to avoid any cache collision with other tests + const URL_MESH = '/array-test-mesh.glb' + const URL_GROUP = '/array-test-group.glb' + + const MockMesh = new THREE.Mesh() + const MockGroup = new THREE.Group() + const mat1 = new THREE.MeshBasicMaterial() + mat1.name = 'Mat 1' + const mesh1 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat1) + mesh1.name = 'Mesh 1' + const mat2 = new THREE.MeshBasicMaterial() + mat2.name = 'Mat 2' + const mesh2 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat2) + mesh2.name = 'Mesh 2' + MockGroup.add(mesh1, mesh2) + + // Use URL-based mock instead of mockImplementationOnce since + // individual suspend calls may re-render and call load multiple times + class ArrayTestLoader extends THREE.Loader { + load = jest.fn((url: string, onLoad: (result: any) => void) => { + if (url === URL_MESH) onLoad(MockMesh) + else if (url === URL_GROUP) onLoad(MockGroup) + }) + } + + const extensions = jest.fn() + + const Component = () => { + const [mockMesh, mockScene] = useLoader(ArrayTestLoader, [URL_MESH, URL_GROUP], extensions) + + return ( + <> + + + + ) + } + + const store = await act(async () => root.render()) + const { scene } = store.getState() + + expect(scene.children[0]).toBe(MockMesh) + expect(scene.children[1]).toBe(MockGroup) + // Extensions called once per URL (may be called more due to re-renders, but at least 2) + expect(extensions.mock.calls.length).toBeGreaterThanOrEqual(2) + + // Clean up cache + useLoader.clear(ArrayTestLoader, [URL_MESH, URL_GROUP]) + }) + + it('can handle useLoader with an existing loader instance', async () => { + class Loader extends THREE.Loader { + load(_url: string, onLoad: (result: null) => void): void { + onLoad(null) + } + } + + const loader = new Loader() + let proto!: Loader + + function Test(): null { + return useLoader(loader, '', (loader) => (proto = loader)) + } + await act(async () => root.render()) + + expect(proto).toBe(loader) + }) + + it('can handle useLoader with a loader extension', async () => { + class Loader extends THREE.Loader { + load(_url: string, onLoad: (result: null) => void): void { + onLoad(null) + } + } + + let proto!: Loader + + function Test(): null { + return useLoader(Loader, '', (loader) => (proto = loader)) + } + await act(async () => root.render()) + + expect(proto).toBeInstanceOf(Loader) + }) + + it('useLoader.preload with array caches each URL individually', async () => { + const loadCalls: string[] = [] + + class TestLoader extends THREE.Loader { + load(url: string, onLoad: (result: string) => void): void { + loadCalls.push(url) + onLoad(`loaded:${url}`) + } + } + + const URL_A = '/model-a.glb' + const URL_B = '/model-b.glb' + + // Preload with an array - this should cache each URL individually + useLoader.preload(TestLoader, [URL_A, URL_B]) + + // Wait for preload promises to resolve + await new Promise((resolve) => setTimeout(resolve, 10)) + + // Clear load tracking to isolate the useLoader calls + const preloadCallCount = loadCalls.length + expect(preloadCallCount).toBe(2) // Both URLs should have been loaded + + // Now use useLoader with individual URLs - should hit cache, not reload + let resultA: string | undefined + let resultB: string | undefined + + const ComponentA = () => { + resultA = useLoader(TestLoader, URL_A) + return null + } + + const ComponentB = () => { + resultB = useLoader(TestLoader, URL_B) + return null + } + + await act(async () => root.render()) + await act(async () => root.render()) + + // The loader should NOT have been called again - cache should have been hit + expect(loadCalls.length).toBe(2) // Still just the 2 preload calls + expect(resultA).toBe(`loaded:${URL_A}`) + expect(resultB).toBe(`loaded:${URL_B}`) + + // Clean up cache for other tests + useLoader.clear(TestLoader, [URL_A, URL_B]) + }) + + it('can abort loader with loadAsync and clears suspend', async () => { + const URL_SLOW = '/slow-load.glb' + let abortCalled = false + let loadCompleted = false + let loadAsyncCalled = false + + class AbortableLoader extends THREE.Loader { + private abortController: (() => void) | null = null + + // Implement loadAsync for abort support + loadAsync(url: string, onProgress?: (event: ProgressEvent) => void): Promise { + loadAsyncCalled = true + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + loadCompleted = true + this.abortController = null + resolve(`loaded:${url}`) + }, 5000) // 5 second load time + + // Store abort handler + this.abortController = () => { + abortCalled = true + clearTimeout(timeoutId) + reject(new Error('Load aborted')) + } + }) + } + + // Standard load method (won't be used since loadAsync exists) + load(url: string, onLoad: (result: string) => void): void { + onLoad(`loaded:${url}`) + } + + // Override abort to call our controller + abort(): this { + if (this.abortController) { + this.abortController() + this.abortController = null + } + return super.abort() + } + } + + const Component = () => { + const result = useLoader(AbortableLoader, URL_SLOW) + return + } + + // Start rendering (will suspend and trigger loadAsync) + let renderError: any = null + act(() => { + try { + root.render( + }> + + , + ) + } catch (err) { + renderError = err + } + }) + + // Wait for loadAsync to be called by suspend-react + await new Promise((resolve) => setTimeout(resolve, 100)) + expect(loadAsyncCalled).toBe(true) + expect(loadCompleted).toBe(false) + + // Get the loader instance and abort after 1 second + await new Promise((resolve) => setTimeout(resolve, 1000)) + const loaderInstance = useLoader.loader(AbortableLoader) + loaderInstance.abort() + + // Verify abort was called + expect(abortCalled).toBe(true) + expect(loadCompleted).toBe(false) + + // Wait a bit for the abort to propagate + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Clear the cache to ensure suspend is cleared + useLoader.clear(AbortableLoader, URL_SLOW) + + // Reset flags for next test + loadAsyncCalled = false + abortCalled = false + + // Verify the suspended load is no longer cached + // If we try to use it again, it should start a fresh load + const Component2 = () => { + const result = useLoader(AbortableLoader, URL_SLOW) + return + } + + // This should trigger a new load since cache was cleared + act(() => { + try { + root.render( + }> + + , + ) + } catch (err) { + // Expected to suspend again + } + }) + + // Wait to let the new load attempt start + await new Promise((resolve) => setTimeout(resolve, 100)) + expect(loadAsyncCalled).toBe(true) // New load started + expect(abortCalled).toBe(false) // No abort on this new attempt yet + + // Clean up + const newLoader = useLoader.loader(AbortableLoader) + newLoader.abort() + useLoader.clear(AbortableLoader, URL_SLOW) + }) +})