Skip to content

Support Mixing AD Frameworks for LogDensityProblems and the objective #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 111 commits into from

Conversation

Red-Portal
Copy link
Member

@Red-Portal Red-Portal commented Jul 7, 2025

This PR enables mixing AD frameworks for the target LogDensityProblem and differentiating through the variational objective. This is enabled by defining a custom rrule through ChainRulesCore. As such, we are restricted to AD frameworks that support importing rrules. However, we still support directly differentiating through LogDensityProblems.logdensity if the target LogDensityProblem only has zeroth-order capability.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 3bb03f9 Previous: d160060 Ratio
normal/RepGradELBO + STL/meanfield/Zygote 3657107374 ns 3911593378.5 ns 0.93
normal/RepGradELBO + STL/meanfield/ReverseDiff 1102349821 ns 1126577916 ns 0.98
normal/RepGradELBO + STL/meanfield/Mooncake 1134389980 ns 1188520969 ns 0.95
normal/RepGradELBO + STL/fullrank/Zygote 3638817421.5 ns 3851083566.5 ns 0.94
normal/RepGradELBO + STL/fullrank/ReverseDiff 1579876196 ns 1611527605.5 ns 0.98
normal/RepGradELBO + STL/fullrank/Mooncake 1214828309 ns 1248235339 ns 0.97
normal/RepGradELBO/meanfield/Zygote 2547816176 ns 2738634121.5 ns 0.93
normal/RepGradELBO/meanfield/ReverseDiff 756326907 ns 770435862 ns 0.98
normal/RepGradELBO/meanfield/Mooncake 1021460616 ns 1062623223 ns 0.96
normal/RepGradELBO/fullrank/Zygote 2585227666.5 ns 2780952483.5 ns 0.93
normal/RepGradELBO/fullrank/ReverseDiff 933675264.5 ns 958044713 ns 0.97
normal/RepGradELBO/fullrank/Mooncake 1099030538 ns 1142753555 ns 0.96
normal + bijector/RepGradELBO + STL/meanfield/Zygote 5284409807 ns 5608097016 ns 0.94
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff 2342704841 ns 2448081612 ns 0.96
normal + bijector/RepGradELBO + STL/meanfield/Mooncake 3869237236 ns 3958456262.5 ns 0.98
normal + bijector/RepGradELBO + STL/fullrank/Zygote 5440495726 ns 5573088994 ns 0.98
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff 3027567497 ns 3058706489 ns 0.99
normal + bijector/RepGradELBO + STL/fullrank/Mooncake 4033537126 ns 4098676416.5 ns 0.98
normal + bijector/RepGradELBO/meanfield/Zygote 4188547956.5 ns 4289915836.5 ns 0.98
normal + bijector/RepGradELBO/meanfield/ReverseDiff 2011523681 ns 2073166583 ns 0.97
normal + bijector/RepGradELBO/meanfield/Mooncake 3719413465.5 ns 3800475334 ns 0.98
normal + bijector/RepGradELBO/fullrank/Zygote 4302378001.5 ns 4419557304 ns 0.97
normal + bijector/RepGradELBO/fullrank/ReverseDiff 2287039896 ns 2342946249 ns 0.98
normal + bijector/RepGradELBO/fullrank/Mooncake 3899698655 ns 3965463115.5 ns 0.98

This comment was automatically generated by workflow using github-action-benchmark.

Red-Portal and others added 2 commits July 30, 2025 17:57
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@Red-Portal
Copy link
Member Author

This PR is superseded by a cleaned-up version #187

@Red-Portal Red-Portal closed this Jul 30, 2025
@yebai yebai deleted the mixed_ad branch August 19, 2025 20:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants