Skip to content

Feature: optional forward diff mode#14

Open
mattia-spider wants to merge 1 commit into
camml-lab:developfrom
mattia-spider:feature/forward-mode-diff
Open

Feature: optional forward diff mode#14
mattia-spider wants to merge 1 commit into
camml-lab:developfrom
mattia-spider:feature/forward-mode-diff

Conversation

@mattia-spider

Copy link
Copy Markdown
Collaborator

Add optional forward-mode autodiff to diff() for cases where dim(input) << dim(output).

Selectable via mode="fwd" (default "rev", so existing uses like force computation are unchanged). Forward mode drastically reduces memory overhead when differentiating w.r.t. a low-dimensional input (e.g. the 3-component external field for NMR tensors), since it needs O(n_input) passes instead of O(n_output).

Memory impact, same architecture on qm9-nmr:

80GB H100 (with batch size 8): peak_MiB: 35477 (rev) → peak_MiB: 4849 (fwd)

8GB GPU: max batch 3 (rev) → 11 (fwd)
80GB H100: max batch 8 (rev) → 40 (fwd)

Correctness: fwd and rev give the same Jacobian, max abs diff ~3e-6 (float32), ~1e-15 (float64).
TODO (lines 297-303): forward mode breaks e3nn IrrepsArray's shape invariant (jacfwd appends the tangent axis, colliding with the irreps dim on rewrap). Worked around with value = base.as_array(value), active only in fwd mode. Proper fix belongs upstream in e3nn; remove workaround once resolved.

@mattia-spider mattia-spider requested a review from muhrin June 19, 2026 13:12
@codecov-commenter

codecov-commenter commented Jun 19, 2026

Copy link
Copy Markdown

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/tensorial/gcnn/_diff.py 87.50% 1 Missing ⚠️
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.
Files with missing lines Coverage Δ
src/tensorial/gcnn/_diff.py 79.55% <87.50%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@muhrin

muhrin commented Jun 20, 2026

Copy link
Copy Markdown
Member

Great! Could you try removing your x.array workaround for forward mode together with my e3nn-jax version: https://github.com/muhrin/e3nn-jax/tree/develop

I think I've fixed the bug affecting the Jacobian.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants