Summary
slangpy's Function object provides .bwds() for backward-mode differentiation, but has no .fwds() for forward-mode differentiation (Jacobian-vector products).
Expected Behavior
slangtorch provides both (see the note after this list):
- .fwd() - forward-mode differentiation (computes Jacobian-vector products)
- .bwd() - backward-mode differentiation (computes vector-Jacobian products)
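Note: for a function f : ℝⁿ → ℝᵐ with Jacobian J, forward mode computes the Jacobian-vector product J·v for a tangent v ∈ ℝⁿ, while backward mode computes the vector-Jacobian product Jᵀ·u for a cotangent u ∈ ℝᵐ.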
Actual Behavior
slangpy only provides:
- .bwds() - backward-mode differentiation
No forward-mode equivalent exists.
Use Case
Forward-mode differentiation is useful for:
- Computing directional derivatives
- Jacobian-vector products without materializing the full Jacobian
- Cases where output dimension >> input dimension, where forward-mode is more efficient (see the sketch after this list)
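As a concrete illustration of the semantics being requested, here is a minimal JVP sketch using torch.func.jvp (an existing PyTorch API) on a pure-Python stand-in for the square kernel from the reproduction below. The square function, x, and dx here are illustrative only, not part of slangpy:

```python
import torch

# Pure-Python stand-in for the Slang kernel: y = x * x, elementwise.
def square(x):
    return x * x

x = torch.tensor([3.0, 4.0])   # primal input
dx = torch.tensor([1.0, 0.0])  # tangent (direction) vector

# Forward mode propagates the tangent alongside the primal computation:
# a single pass yields both y and J @ dx, without materializing J.
y, dy = torch.func.jvp(square, (x,), (dx,))
# y  == [9.0, 16.0]
# dy == [6.0, 0.0]   (dy_i = 2 * x_i * dx_i)
```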
Reproduction
square_diff.slang:

```slang
[Differentiable]
void square(no_diff int tid, DiffTensorView input, DiffTensorView output)
{
    if (tid >= input.size(0)) return;
    output.store(tid, input[tid] * input[tid]);
}
```

test_fwd.py:

```python
import torch
import slangpy as spy
from slangpy.torchintegration import diff_pair
device = spy.Device(type=spy.DeviceType.cuda)
module = spy.Module.load_from_source(device, "square_diff.slang")
X = torch.tensor([3.0, 4.0], device="cuda")
Y = torch.zeros_like(X)
tid = torch.arange(2, device="cuda", dtype=torch.int32)
# Backward works
module.square.bwds(tid=tid, input=diff_pair(X, torch.zeros_like(X)),
output=diff_pair(Y, torch.ones_like(Y)))
# Forward doesn't exist
module.square.fwds(...)  # AttributeError: no .fwds method
```

slangtorch equivalent:
```python
# slangtorch forward-mode diff
m.square.fwd(input=(X, dX), output=(Y, dY)).launchRaw(...)
```
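For parity, a .fwds() mirroring the existing .bwds() calling convention could look like the following. This is a hypothetical sketch of the requested API, not an existing slangpy method; it reuses X, Y, tid, module, and diff_pair from the reproduction above and assumes diff_pair would carry (primal, tangent) in forward mode, analogous to (primal, gradient) in backward mode:

```python
# Hypothetical API sketch -- .fwds() does not exist in slangpy today.
dX = torch.tensor([1.0, 0.0], device="cuda")  # tangent (direction) for the input
dY = torch.zeros_like(Y)                      # output tangent, written by the call
module.square.fwds(tid=tid, input=diff_pair(X, dX),
                   output=diff_pair(Y, dY))
# After the call, dY would hold the JVP: dY[i] = 2 * X[i] * dX[i].
```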
Impact
Medium - Forward-mode differentiation is less commonly needed than backward-mode, but it is important for certain applications (sensitivity analysis, directional derivatives).
Environment
- slangpy 0.40.1
- CUDA 12.x
- PyTorch 2.10.0