Add deprecated batch_size parameter for backward compatibility #393
Conversation
I don't think your method will work, as it still broadcasts the arrays:
import jax
import jax.numpy as jnp


@jax.jit
def rays_intersect_any_triangle(
ray_origins: jax.Array,
ray_directions: jax.Array,
triangle_vertices: jax.Array
) -> jax.Array:
"""
Checks if each ray intersects *any* of the provided triangles.
Optimized to use broadcasting and reduction to minimize intermediate
memory footprint, avoiding explicit loops or jax.lax.scan.
Args:
ray_origins: (num_rays, 3) array.
ray_directions: (num_rays, 3) array.
triangle_vertices: (num_triangles, 3, 3) array.
Returns:
(num_rays,) boolean array.
"""
epsilon = 1e-6
# --- 1. Triangle Pre-computation (Done once for all rays) ---
# Shapes: (M, 3)
# Vertices
v0 = triangle_vertices[:, 0, :]
v1 = triangle_vertices[:, 1, :]
v2 = triangle_vertices[:, 2, :]
# Edges
edge1 = v1 - v0
edge2 = v2 - v0
# Triangle Normal (unnormalized)
# N = edge1 x edge2
tri_normal = jnp.cross(edge1, edge2)
# Auxiliary terms for Scalar Triple Product expansions
# We use these to separate Ray terms from Triangle terms
# Term for t_num: C = v0 . N
# shape: (M,)
v0_dot_n = jnp.einsum('mk,mk->m', v0, tri_normal)
# Term for u_num: cross(edge2, v0)
# shape: (M, 3)
nx_u = jnp.cross(edge2, v0)
# Term for v_num: cross(edge1, v0)
# shape: (M, 3)
nx_v = jnp.cross(edge1, v0)
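    # Derivation sketch (Möller–Trumbore terms rearranged so that all
    # triangle-only quantities are hoisted out of the per-ray work):
    # with ray O + t*D, pvec = D x edge2, and qvec = (O - v0) x edge1:
    #   det     = edge1 . pvec = -D . N
    #   u * det = (O - v0) . pvec = (O x D) . edge2 - D . (edge2 x v0)
    #   v * det = D . qvec        = D . (edge1 x v0) - (O x D) . edge1
    #   t * det = edge2 . qvec    = (O - v0) . N = O . N - v0 . N
    # The only per-ray quantities are O, D, and O x D (ray_cross below).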
    # --- 2. Scalarized operands for jax.lax.reduce ---
    # We use jax.lax.reduce to perform a parallel reduction of the intersection
    # check across all triangles for each ray. This avoids materializing the
    # (N, M) boolean matrix.
    # jax.lax.reduce requires every accumulator to be a scalar (0-d), so the
    # (M, 3) vector operands cannot be passed directly. We therefore split each
    # (M, 3) array into three (M,) arrays and pass them as separate operands.
    # To reduce K operands we must supply K accumulators; only the first (the
    # hit boolean) is meaningful, the rest are dummies that satisfy the
    # reduction structure.
    # Helper to split (M, 3) -> (M,), (M,), (M,)
def split_vec(arr):
return (arr[:, 0], arr[:, 1], arr[:, 2])
# Helper to reconstruct (3,) from components
def stack_vec(c0, c1, c2):
return jnp.array([c0, c1, c2])
    # Prepare scalar operands
num_tris = v0.shape[0]
dummy_bool = jnp.zeros((num_tris,), dtype=bool)
# Split triangle vectors
op_v0 = split_vec(v0)
op_e1 = split_vec(edge1)
op_e2 = split_vec(edge2)
op_n = split_vec(tri_normal)
op_v0n = (v0_dot_n,) # Scalar array (M,) tuple
op_nxu = split_vec(nx_u)
op_nxv = split_vec(nx_v)
# Combine triangle operands (tuple of 1 + 3*6 + 1 = 20 arrays)
base_operands = (dummy_bool,) + op_v0 + op_e1 + op_e2 + op_n + op_v0n + op_nxu + op_nxv
def check_any_hit(r_o, r_d):
"""
Checks if ray (r_o, r_d) hits any triangle using jax.lax.reduce.
(Parallel reduction over Triangles).
"""
        # Broadcast ray data to operand shape (M, 3), then split into components
r_o_b = jnp.broadcast_to(r_o, (num_tris, 3))
r_d_b = jnp.broadcast_to(r_d, (num_tris, 3))
op_ro = split_vec(r_o_b)
op_rd = split_vec(r_d_b)
# All operands: 20 + 3 + 3 = 26 arrays of shape (M,)
all_operands = base_operands + op_ro + op_rd
# Inits: 26 scalars.
# 0: False
# 1-25: 0.0 (float)
init_vals = (False,) + (jnp.array(0.0),) * 25
# Body function: (acc_tuple, input_tuple) -> acc_tuple
def reduce_body(acc_seq, input_seq):
acc_hit = acc_seq[0]
# Reconstruct vectors from scalar inputs
# Input map:
# 0: dummy
# 1-3: v0
# 4-6: e1
# 7-9: e2
# 10-12: n
# 13: v0n
# 14-16: nxu
# 17-19: nxv
# 20-22: ro
# 23-25: rd
tv0 = stack_vec(input_seq[1], input_seq[2], input_seq[3])
te1 = stack_vec(input_seq[4], input_seq[5], input_seq[6])
te2 = stack_vec(input_seq[7], input_seq[8], input_seq[9])
tn = stack_vec(input_seq[10], input_seq[11], input_seq[12])
tv0n = input_seq[13] # Scalar
tnx_u = stack_vec(input_seq[14], input_seq[15], input_seq[16])
tnx_v = stack_vec(input_seq[17], input_seq[18], input_seq[19])
tr_o = stack_vec(input_seq[20], input_seq[21], input_seq[22])
tr_d = stack_vec(input_seq[23], input_seq[24], input_seq[25])
# --- Scalar Triple Product Algebra ---
det = -jnp.dot(tr_d, tn)
inv_det = 1.0 / det
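            # Note: when det ~ 0 this division yields inf/NaN; the range checks
            # below then evaluate to False, and `not_parallel` masks the result,
            # so no explicit guard around the division is needed.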
ray_cross = jnp.cross(tr_o, tr_d)
t_num = jnp.dot(tr_o, tn) - tv0n
u_num = jnp.dot(ray_cross, te2) - jnp.dot(tr_d, tnx_u)
v_num = jnp.dot(tr_d, tnx_v) - jnp.dot(ray_cross, te1)
# --- Intersection Checks ---
u = u_num * inv_det
v = v_num * inv_det
t = t_num * inv_det
not_parallel = (det > epsilon) | (det < -epsilon)
hit = (
not_parallel &
(u >= 0.0) & (u <= 1.0) &
(v >= 0.0) & ((u + v) <= 1.0) &
(t > 0.0)
)
# Propagate hit accumulator with OR
# Propagate other accumulators unchanged (dummies)
return (acc_hit | hit,) + acc_seq[1:]
# Reduce over dimension 0 of the operands (M triangles)
result_seq = jax.lax.reduce(all_operands, init_vals, reduce_body, (0,))
# We only want the first result (hit boolean)
return result_seq[0]
    # Map the reduction over all rays.
    # base_operands is captured via closure (treated as constant by vmap).
hits_per_ray = jax.vmap(check_any_hit)(ray_origins, ray_directions)
return hits_per_ray
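
For reference, a minimal smoke test of the reduction-based version above (my own illustrative values, not part of the PR):

import jax.numpy as jnp

# One ray that hits the triangle head-on, one that misses it entirely.
origins = jnp.array([[0.0, 0.0, 0.0], [5.0, 5.0, 0.0]])
directions = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
triangles = jnp.array([[[0.0, 0.0, 5.0], [1.0, 0.0, 5.0], [0.0, 1.0, 5.0]]])

print(rays_intersect_any_triangle(origins, directions, triangles))
# Expected: [ True False]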
@eqx.filter_jit
def _ray_intersect_any_triangle(
ray_origin: Float[ArrayLike, "3"],
ray_direction: Float[ArrayLike, "3"],
triangle_vertices: Float[ArrayLike, "num_triangles 3 3"],
active_triangles: Bool[ArrayLike, "num_triangles"] | None = None,
*,
epsilon: Float[ArrayLike, ""] | None = None,
hit_tol: Float[ArrayLike, ""] | None = None,
smoothing_factor: Float[ArrayLike, ""] | None = None,
) -> Bool[Array, " "] | Float[Array, " "]: ...

...that implements the reduction logic on a sequence of triangles. This function should additionally remove unused variables, and implement the logic for smoothing (if `smoothing_factor` is given). The public function

@eqx.filter_jit
def rays_intersect_any_triangle(
ray_origins: Float[ArrayLike, "*#batch 3"],
ray_directions: Float[ArrayLike, "*#batch 3"],
triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"],
active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None,
*,
epsilon: Float[ArrayLike, ""] | None = None,
hit_tol: Float[ArrayLike, ""] | None = None,
smoothing_factor: Float[ArrayLike, ""] | None = None,
) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: ...

...can then be implemented simply by wrapping the private function.

import time

def test_oom_stress():
print("\n--- Running OOM Stress Test (1M Rays x 1M Triangles) ---")
print("Note: This test effectively requests 1 Trillion interaction checks.")
print("Without chunking (scan/loop), this is expected to stress memory significantly.")
# 100,000 rays and triangles
N = 100_000
M = 100_000
    # Create tiny input arrays and broadcast them to the full workload shape:
    # jnp.broadcast_to gives us the right shapes without the memory cost of
    # materializing the inputs, so any OOM comes from the algorithm itself.
ray_o = jnp.zeros((1, 3))
ray_d = jnp.array([[0.0, 0.0, 1.0]])
tri = jnp.array([[[0.0, 0.0, 5.0], [1.0, 0.0, 5.0], [0.0, 1.0, 5.0]]])
big_rays_o = jnp.broadcast_to(ray_o, (N, 3))
big_rays_d = jnp.broadcast_to(ray_d, (N, 3))
big_tris = jnp.broadcast_to(tri, (M, 3, 3))
print(f"Input shapes virtualized: Rays {big_rays_o.shape}, Tris {big_tris.shape}")
    try:  # We expect OOM here: this materializes the full (N, M) boolean matrix
        # ray_intersect_triangle: the library's single-ray/single-triangle kernel
        vmap_tris = jax.vmap(ray_intersect_triangle, in_axes=(None, None, 0))
        # vmap_rays takes (N, 3), (N, 3), (M, 3, 3) -> (N, M)
        vmap_rays = jax.vmap(vmap_tris, in_axes=(0, 0, None))
        out = vmap_rays(big_rays_o, big_rays_d, big_tris).block_until_ready()
        print("UNEXPECTED: naive vmap-over-vmap did not run out of memory.")
    except Exception as e:
        print(f"EXPECTED: naive approach failed ({e})")
try: # We do not expect OOM here
start = time.time()
# We don't block_until_ready immediately to see if graph construction fails
res = rays_intersect_any_triangle(big_rays_o, big_rays_d, big_tris)
res.block_until_ready()
print(f"SUCCESS: Computed 1T interactions in {time.time() - start:.2f}s")
print(f"Result shape: {res.shape}")
    except Exception as e:
        print("FAILURE: the reduction-based implementation also hit allocation limits.")
        print(f"Error: {e}")
Codecov Report

✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##             main     #393       +/-   ##
===========================================
- Coverage   84.95%   34.20%   -50.76%
===========================================
  Files          32        4       -28
  Lines        3051      538     -2513
===========================================
- Hits         2592      184     -2408
+ Misses        459      354     -105
Description

Tests fail because they pass `batch_size=11` to `rays_intersect_any_triangle`, but the new `jax.lax.reduce` implementation doesn't accept this parameter.

Changes:

- Added `batch_size: int | None = None` with `noqa: ARG001` to accept but ignore the parameter (a sketch of the resulting signature follows this list)
- Passed `epsilon` explicitly and removed `**kwargs` for cleaner type hints
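
A sketch of the resulting signature (paraphrased from the description; everything except the `batch_size` line follows the signatures quoted earlier in the thread):

@eqx.filter_jit
def rays_intersect_any_triangle(
    ray_origins: Float[ArrayLike, "*#batch 3"],
    ray_directions: Float[ArrayLike, "*#batch 3"],
    triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"],
    active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None,
    *,
    epsilon: Float[ArrayLike, ""] | None = None,
    batch_size: int | None = None,  # noqa: ARG001 - deprecated, ignored
) -> Bool[Array, " *batch"]:
    ...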
Behavior:

The new implementation uses `jax.lax.reduce`, which processes all triangles without batching, making this parameter obsolete. It is retained for backward compatibility only.

Checklist
Note to reviewers
The parameter is marked deprecated but not removed, to avoid breaking existing test code. Consider removing it in the next major version.
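
If a louder signal is wanted before removal, one option (not part of this PR; a hypothetical helper using the standard warnings module) would be to warn when the argument is actually passed:

import warnings

def _warn_deprecated_batch_size(batch_size):
    # Hypothetical helper, called at the top of rays_intersect_any_triangle.
    if batch_size is not None:
        warnings.warn(
            "batch_size is deprecated and ignored; the jax.lax.reduce-based "
            "implementation processes all triangles without batching.",
            DeprecationWarning,
            stacklevel=3,
        )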