Add deprecated batch_size parameter for backward compatibility #393

Draft

Copilot wants to merge 7 commits into main from copilot/optimize-ray-tracing-algorithm

Conversation


Copilot AI commented Feb 12, 2026

Description

Tests fail because they pass batch_size=11 to rays_intersect_any_triangle, but the new jax.lax.reduce implementation doesn't accept this parameter.

Changes:

  • Main function: added batch_size: int | None = None (with noqa: ARG001) so the parameter is accepted but ignored
  • Overload signatures: added epsilon explicitly and removed **kwargs for cleaner type hints
  • Documentation: marked the parameter as deprecated in the docstring

Behavior:

# Both work identically - parameter ignored in new implementation
hits = rays_intersect_any_triangle(rays, dirs, tris, batch_size=11)  # Old tests
hits = rays_intersect_any_triangle(rays, dirs, tris)  # New code

The new implementation uses jax.lax.reduce, which processes all triangles without batching, making this parameter obsolete; it is retained for backward compatibility only.
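For illustration only, a minimal sketch of this pattern (the parameter list is abridged and is not the exact code of this PR):

def rays_intersect_any_triangle(
    ray_origins,
    ray_directions,
    triangle_vertices,
    *,
    batch_size: int | None = None,  # noqa: ARG001 - deprecated, ignored
    epsilon: float | None = None,
):
    """Return, for each ray, whether it intersects any triangle.

    Args:
        batch_size: Deprecated and ignored; the jax.lax.reduce-based
            implementation processes all triangles at once. Kept only for
            backward compatibility.
        epsilon: Tolerance used by the intersection test.
    """
    ...  # reduction-based implementation (see discussion below)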

Checklist

  • I understand that my contributions need to pass the checks;
  • If I created new functions / methods, I documented them and added type hints;
  • If I modified already existing code, I updated the documentation accordingly;
  • The title of my pull request is a short description of the requested changes.

Note to reviewers

The parameter is marked as deprecated but not removed, to avoid breaking existing test code. Consider removing it in the next major version; one possible follow-up is sketched below.
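For example, a follow-up (not part of this PR) could emit a runtime warning when callers still pass the parameter; the helper name below is hypothetical:

import warnings


def _warn_if_batch_size(batch_size: int | None) -> None:
    # Hypothetical helper: warn callers that still pass the ignored parameter.
    if batch_size is not None:
        warnings.warn(
            "'batch_size' is deprecated and ignored by the reduction-based "
            "implementation; it will be removed in the next major release.",
            DeprecationWarning,
            stacklevel=3,
        )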

Original prompt

I need you to act as a JAX and Computer Graphics expert. I am facing a memory bottleneck in my ray-tracing library and need a specialized implementation of the Möller-Trumbore intersection algorithm.

Context:
I currently have a function rays_intersect_triangles(rays, triangles) that returns intersection details.

  • rays shape: (num_rays, 3, 2) (Origins and Directions) or similar.
  • triangles shape: (num_triangles, 3, 3) (Vertices).
  • The Problem: Applying jax.vmap over both axes creates a (num_rays, num_triangles) boolean matrix. With large inputs, this causes OOM (Out Of Memory) errors.
  • Current Attempt: I tried chunking/batching the vmap, but it is too slow due to compilation overhead and dispatch.

The Goal:
Write a new function rays_intersect_any_triangle(rays, triangles) that returns a boolean array of shape (num_rays,).

  • True if the ray hits at least one triangle.
  • False otherwise.

Technical Constraints & Requirements:

  1. Memory Efficiency: You must strictly avoid materializing the full num_rays x num_triangles interaction matrix.
  2. Algorithm: Use the Möller-Trumbore algorithm, but inline it and specialize it for a boolean result.
  3. Vector Algebra:
    • Do not simply vmap the existing logic.
    • Instead, use jax.numpy.einsum or vector algebra to compute the necessary dot products and cross products for all pairs efficiently.
    • Crucially, attempt to perform a reduction (like sum or any or max) immediately after computing the components (determinant, u, v, t) to collapse the triangle dimension before the memory explodes.
  4. No Full Broadcast: The logic should verify the bounds ($u \ge 0, v \ge 0, u+v \le 1, t > 0$) without broadcasting the full float arrays if possible. Do not use any loop or jax.lax.scan to go over the triangles dimension, because that would prevent efficient parallel execution.

Mathematical Hint:
Recall that for Möller-Trumbore:
$O + tD = (1-u-v)V_0 + uV_1 + vV_2$
You can check if an intersection exists by checking the signs of determinants and dot products relative to each other, potentially avoiding the final division until necessary.
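Concretely, with $E_1 = V_1 - V_0$, $E_2 = V_2 - V_0$, $T = O - V_0$, $P = D \times E_2$, and $Q = T \times E_1$, the standard Cramer's-rule solution is $\det = P \cdot E_1$, $u = (P \cdot T)/\det$, $v = (Q \cdot D)/\det$, $t = (Q \cdot E_2)/\det$, so the bounds $u \ge 0$, $v \ge 0$, $u + v \le 1$, $t > 0$ can be checked by comparing $P \cdot T$, $Q \cdot D$, and $Q \cdot E_2$ against $\det$ (accounting for its sign), deferring the division.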

Please generate the JAX code for rays_intersect_any_triangle.



Copilot AI and others added 3 commits February 12, 2026 18:19
Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
… ops

Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Improve memory efficiency of Möller-Trumbore intersection algorithm" to "Vectorize rays_intersect_any_triangle to eliminate batching overhead" on Feb 12, 2026
Copilot AI requested a review from jeertmans February 12, 2026 18:26
@jeertmans
Owner

I don't think your method will work, as it still broadcasts the num_triangles and num_rays dimensions together. I think I have found a solution, and I would like you to implement it.

  1. Read the solution (code below) that was generated for a simpler version of rays_intersect_any_triangle (I omitted the epsilon parameter and the option to apply smoothing):
import jax
import jax.numpy as jnp


@jax.jit
def rays_intersect_any_triangle(
    ray_origins: jax.Array,
    ray_directions: jax.Array,
    triangle_vertices: jax.Array
) -> jax.Array:
    """
    Checks if each ray intersects *any* of the provided triangles.
    
    Optimized to use broadcasting and reduction to minimize intermediate 
    memory footprint, avoiding explicit loops or jax.lax.scan.
    
    Args:
        ray_origins: (num_rays, 3) array.
        ray_directions: (num_rays, 3) array.
        triangle_vertices: (num_triangles, 3, 3) array.
        
    Returns:
        (num_rays,) boolean array.
    """
    epsilon = 1e-6
    
    # --- 1. Triangle Pre-computation (Done once for all rays) ---
    # Shapes: (M, 3)
    
    # Vertices
    v0 = triangle_vertices[:, 0, :]
    v1 = triangle_vertices[:, 1, :]
    v2 = triangle_vertices[:, 2, :]
    
    # Edges
    edge1 = v1 - v0
    edge2 = v2 - v0
    
    # Triangle Normal (unnormalized)
    # N = edge1 x edge2
    tri_normal = jnp.cross(edge1, edge2)
    
    # Auxiliary terms for Scalar Triple Product expansions
    # We use these to separate Ray terms from Triangle terms
    
    # Term for t_num: C = v0 . N
    # shape: (M,)
    v0_dot_n = jnp.einsum('mk,mk->m', v0, tri_normal)
    
    # Term for u_num: cross(edge2, v0)
    # shape: (M, 3)
    nx_u = jnp.cross(edge2, v0)
    
    # Term for v_num: cross(edge1, v0)
    # shape: (M, 3)
    nx_v = jnp.cross(edge1, v0)
    
    # --- 2. Vectorization & Reduction ---
    
    # We use jax.lax.reduce to perform a parallel reduction of the intersection
    # check across all triangles for each ray, which avoids materializing the
    # (N, M) boolean matrix. The pre-computed triangle data is passed as operands
    # to the reduction, which iterates over axis 0 (triangles) and hands
    # single-triangle slices to the body function.
    #
    # jax.lax.reduce requires every accumulator to be a scalar (0-d), so we
    # cannot pass (M, 3) operands together with (3,) dummy accumulators.
    # Instead we scalarize everything: each (M, 3) array is split into three
    # (M,) arrays, and every operand gets a matching dummy scalar accumulator.
    # Only the first accumulator (the hit boolean) carries the result; the
    # others exist solely to satisfy the reduction structure.
    
    # Helper to split (M, 3) -> (M,), (M,), (M,)
    def split_vec(arr):
        return (arr[:, 0], arr[:, 1], arr[:, 2])
        
    # Helper to reconstruct (3,) from components
    def stack_vec(c0, c1, c2):
        return jnp.array([c0, c1, c2])
    
    # 1. Prepare scalar operands
    num_tris = v0.shape[0]
    dummy_bool = jnp.zeros((num_tris,), dtype=bool)
    
    # Split triangle vectors
    op_v0 = split_vec(v0)
    op_e1 = split_vec(edge1)
    op_e2 = split_vec(edge2)
    op_n  = split_vec(tri_normal)
    op_v0n = (v0_dot_n,) # Scalar array (M,) tuple
    op_nxu = split_vec(nx_u)
    op_nxv = split_vec(nx_v)
    
    # Combine triangle operands (tuple of 1 + 3*6 + 1 = 20 arrays)
    base_operands = (dummy_bool,) + op_v0 + op_e1 + op_e2 + op_n + op_v0n + op_nxu + op_nxv
    
    def check_any_hit(r_o, r_d):
        """
        Checks if ray (r_o, r_d) hits any triangle using jax.lax.reduce.
        (Parallel reduction over Triangles).
        """
        
        # Broadcast ray data to be operands (M, 3) then split
        r_o_b = jnp.broadcast_to(r_o, (num_tris, 3))
        r_d_b = jnp.broadcast_to(r_d, (num_tris, 3))
        
        op_ro = split_vec(r_o_b)
        op_rd = split_vec(r_d_b)
        
        # All operands: 20 + 3 + 3 = 26 arrays of shape (M,)
        all_operands = base_operands + op_ro + op_rd
        
        # Inits: 26 scalars.
        # 0: False
        # 1-25: 0.0 (float)
        init_vals = (False,) + (jnp.array(0.0),) * 25
        
        # Body function: (acc_tuple, input_tuple) -> acc_tuple
        def reduce_body(acc_seq, input_seq):
            acc_hit = acc_seq[0]
            
            # Reconstruct vectors from scalar inputs
            # Input map:
            # 0: dummy
            # 1-3: v0
            # 4-6: e1
            # 7-9: e2
            # 10-12: n
            # 13: v0n
            # 14-16: nxu
            # 17-19: nxv
            # 20-22: ro
            # 23-25: rd
            
            tv0 = stack_vec(input_seq[1], input_seq[2], input_seq[3])
            te1 = stack_vec(input_seq[4], input_seq[5], input_seq[6])
            te2 = stack_vec(input_seq[7], input_seq[8], input_seq[9])
            tn  = stack_vec(input_seq[10], input_seq[11], input_seq[12])
            tv0n = input_seq[13] # Scalar
            tnx_u = stack_vec(input_seq[14], input_seq[15], input_seq[16])
            tnx_v = stack_vec(input_seq[17], input_seq[18], input_seq[19])
            tr_o = stack_vec(input_seq[20], input_seq[21], input_seq[22])
            tr_d = stack_vec(input_seq[23], input_seq[24], input_seq[25])
            
            # --- Scalar Triple Product Algebra ---
            
            det = -jnp.dot(tr_d, tn)
            inv_det = 1.0 / det
            ray_cross = jnp.cross(tr_o, tr_d)
            t_num = jnp.dot(tr_o, tn) - tv0n
            u_num = jnp.dot(ray_cross, te2) - jnp.dot(tr_d, tnx_u)
            v_num = jnp.dot(tr_d, tnx_v) - jnp.dot(ray_cross, te1)
            
            # --- Intersection Checks ---
            
            u = u_num * inv_det
            v = v_num * inv_det
            t = t_num * inv_det
            
            not_parallel = (det > epsilon) | (det < -epsilon)
            
            hit = (
                not_parallel &
                (u >= 0.0) & (u <= 1.0) &
                (v >= 0.0) & ((u + v) <= 1.0) &
                (t > 0.0)
            )
            
            # Propagate hit accumulator with OR
            # Propagate other accumulators unchanged (dummies)
            return (acc_hit | hit,) + acc_seq[1:]
        
        # Reduce over dimension 0 of the operands (M triangles)
        result_seq = jax.lax.reduce(all_operands, init_vals, reduce_body, (0,))
        
        # We only want the first result (hit boolean)
        return result_seq[0]

    # Map the reduction over all rays.
    # base_operands (the triangle data) are captured from the closure and
    # reused for every ray.
    hits_per_ray = jax.vmap(check_any_hit)(ray_origins, ray_directions)
    
    return hits_per_ray
  2. Create a single-ray, many-triangles function in https://github.com/jeertmans/DiffeRT/blob/main/differt/src/differt/rt/_utils.py with the following signature and name:
@eqx.filter_jit
def _ray_intersect_any_triangle(
    ray_origin: Float[ArrayLike, "3"],
    ray_direction: Float[ArrayLike, "3"],
    triangle_vertices: Float[ArrayLike, "num_triangles 3 3"],
    active_triangles: Bool[ArrayLike, "num_triangles"] | None = None,
    *,
    epsilon: Float[ArrayLike, ""] | None = None,
    hit_tol: Float[ArrayLike, ""] | None = None,
    smoothing_factor: Float[ArrayLike, ""] | None = None,
) -> Bool[Array, " "] | Float[Array, " "]: ...

that implements the reduction logic over a sequence of triangles. This function should additionally remove unused variables and implement the smoothing logic (if smoothing_factor is not None; see the implementation of rays_intersect_triangles for reference).
Then, implement

@eqx.filter_jit
def rays_intersect_any_triangle(
    ray_origins: Float[ArrayLike, "*#batch 3"],
    ray_directions: Float[ArrayLike, "*#batch 3"],
    triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"],
    active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None,
    *,
    epsilon: Float[ArrayLike, ""] | None = None,
    hit_tol: Float[ArrayLike, ""] | None = None,
    smoothing_factor: Float[ArrayLike, ""] | None = None,
) -> Bool[Array, " *batch"] | Float[Array, " *batch"]:

simply by wrapping _ray_intersect_any_triangle with jax.numpy.vectorize and functools.partial (for the scalar keyword arguments); a rough sketch of what this wrapper could look like is given below.
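For instance, assuming _ray_intersect_any_triangle is defined with the signature above, the wrapper could look roughly like the sketch below (the gufunc signature strings and the handling of active_triangles are illustrative assumptions, not the actual DiffeRT code):

from functools import partial

import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit
def rays_intersect_any_triangle(
    ray_origins,
    ray_directions,
    triangle_vertices,
    active_triangles=None,
    *,
    epsilon=None,
    hit_tol=None,
    smoothing_factor=None,
):
    # Bind the scalar keyword arguments first: jnp.vectorize only broadcasts
    # the positional array arguments.
    fun = partial(
        _ray_intersect_any_triangle,
        epsilon=epsilon,
        hit_tol=hit_tol,
        smoothing_factor=smoothing_factor,
    )
    if active_triangles is None:
        # Core dims: origin (3,), direction (3,), triangles (num_triangles, 3, 3) -> scalar.
        vectorized = jnp.vectorize(fun, signature="(k),(k),(m,v,k)->()")
        return vectorized(ray_origins, ray_directions, triangle_vertices)
    vectorized = jnp.vectorize(fun, signature="(k),(k),(m,v,k),(m)->()")
    return vectorized(ray_origins, ray_directions, triangle_vertices, active_triangles)

Binding the scalars with partial keeps them out of the broadcasting machinery, so they stay plain scalars inside the single-ray function.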
3. Then add a test in https://github.com/jeertmans/DiffeRT/blob/main/differt/tests/rt/test_utils.py that checks the implementation does not raise an OOM error. See the example test below:

def test_oom_stress():
    print("\n--- Running OOM Stress Test (100k Rays x 100k Triangles) ---")
    print("Note: This test effectively requests 10 billion interaction checks.")
    print("Without chunking (scan/loop), this is expected to stress memory significantly.")
    
    # 100,000 rays and triangles
    N = 100_000
    M = 100_000
    
    # Use tiny arrays and jnp.broadcast_to to create the workload's shapes
    # without paying the full input-allocation cost up front, so the test
    # stresses the algorithm's memory behavior rather than input creation.
    ray_o = jnp.zeros((1, 3))
    ray_d = jnp.array([[0.0, 0.0, 1.0]])
    tri = jnp.array([[[0.0, 0.0, 5.0], [1.0, 0.0, 5.0], [0.0, 1.0, 5.0]]])
    
    big_rays_o = jnp.broadcast_to(ray_o, (N, 3))
    big_rays_d = jnp.broadcast_to(ray_d, (N, 3))
    big_tris = jnp.broadcast_to(tri, (M, 3, 3))
    
    print(f"Input shapes virtualized: Rays {big_rays_o.shape}, Tris {big_tris.shape}")

    try:  # We expect OOM here
        vmap_tris = jax.vmap(ray_intersect_triangle, in_axes=(None, None, 0))

        # vmap_rays takes (N, 3), (N, 3), (M, 3, 3) -> (N, M)
        vmap_rays = jax.vmap(vmap_tris, in_axes=(0, 0, None))
        vmap_rays(big_rays_o, big_rays_d, big_tris).block_until_ready()
    except Exception as e:
        print(f"Expected failure (OOM) from the naive double-vmap approach: {e}")

    try:  # We do not expect OOM here
        start = time.time()
        # We don't block_until_ready immediately to see if graph construction fails
        res = rays_intersect_any_triangle(big_rays_o, big_rays_d, big_tris)
        res.block_until_ready()
        print(f"SUCCESS: Computed {N * M:,} interactions in {time.time() - start:.2f}s")
        print(f"Result shape: {res.shape}")
    except Exception as e:
        print("FAILURE: Reduction-based implementation crashed under the large allocation.")
        print(f"Error: {e}")

Copilot AI and others added 2 commits February 13, 2026 12:54
Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
Copilot AI changed the title from "Vectorize rays_intersect_any_triangle to eliminate batching overhead" to "Replace vectorized approach with jax.lax.reduce to eliminate memory bottleneck" on Feb 13, 2026
@github-actions github-actions bot added the "python" (Changes Python code) and "tests" (Changes tests) labels on Feb 13, 2026
@codecov

codecov bot commented Feb 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 34.20%. Comparing base (b8929b8) to head (3dc4244).

❗ There is a different number of reports uploaded between BASE (b8929b8) and HEAD (3dc4244): HEAD has 22 fewer uploads than BASE (23 for BASE vs. 1 for HEAD).
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #393       +/-   ##
===========================================
- Coverage   84.95%   34.20%   -50.76%     
===========================================
  Files          32        4       -28     
  Lines        3051      538     -2513     
===========================================
- Hits         2592      184     -2408     
+ Misses        459      354      -105     


Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
Copilot AI changed the title from "Replace vectorized approach with jax.lax.reduce to eliminate memory bottleneck" to "Add deprecated batch_size parameter for backward compatibility" on Feb 13, 2026