Description
I'm trying to register a custom forward-mode derivative with `__enzyme_register_derivative` in C++, based on the `customfwd` test from the Enzyme repository.
Depending on the return type of the differentiated function, I encounter different assertion failures during compilation:
- `int`: `Assertion failed: (returnUsed), function CreateForwardDiff, file EnzymeLogic.cpp, line 4790.`
- `double`: `Assertion failed: (Ty && "Invalid GetElementPtrInst indices for type!"), function checkGEPType, file Instructions.h, line 942.`
Here is the minimal reproducible example:
```cpp
#include <cassert>
#include <cstdlib> // EXIT_SUCCESS
#include <enzyme/enzyme>

// Assertion failed: (returnUsed), function CreateForwardDiff, file EnzymeLogic.cpp, line 4790.
// using T = int;
// Assertion failed: (Ty && "Invalid GetElementPtrInst indices for type!"), function checkGEPType, file Instructions.h, line 942.
using T = double;

T __enzyme_fwddiff(T (*)(T *), T *, T *);

// Primal function to be differentiated.
T square(T *x) { return (*x) * (*x); }

// Custom forward derivative: counts its invocations and returns a sentinel
// value, so main() can tell whether Enzyme actually used it.
int derivative = 0;
T derivative_square(T *x, T *dx) { derivative++; return (T) 100; }

// Register derivative_square as the custom derivative of square.
void* __enzyme_register_derivative_square[] =
{
    (void*)square,
    (void*)derivative_square,
};

T dsquare(T *x, T *dx) { return __enzyme_fwddiff<T>((void*)square, enzyme_dup, x, dx); }

int main()
{
    T x = 3, dx = 1;
    T res = dsquare(&x, &dx);
    assert(derivative == 1);
    assert(res == 100);
    return EXIT_SUCCESS;
}
```
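For reference, the value a genuine forward-mode derivative of `square` should return can be checked without Enzyme at all. The sketch below is plain C++ and not part of Enzyme's API: `square_tangent` and `square_fd` are hypothetical helper names, and the `(x, dx)` argument pairing only mirrors how `enzyme_dup` passes a pointer together with its shadow. It assumes the forward derivative should return the tangent of the result, which for x·x in direction dx is 2·x·dx; a central finite difference should agree with it.

```cpp
#include <cassert>
#include <cmath>

// Primal function from the report, specialised to double.
double square(double *x) { return (*x) * (*x); }

// Hypothetical hand-written tangent of square: given the primal input *x
// and its tangent *dx, return d(x*x) = 2 * x * dx. This is an illustration,
// not Enzyme's documented custom-derivative calling convention.
double square_tangent(double *x, double *dx) { return 2.0 * (*x) * (*dx); }

// Hypothetical central finite-difference estimate of the same
// directional derivative, used only as an independent check.
double square_fd(double x, double dx, double h = 1e-6) {
  double xp = x + h * dx;
  double xm = x - h * dx;
  return (square(&xp) - square(&xm)) / (2.0 * h);
}
```

With `x = 3` and `dx = 1`, `square_tangent` returns 6; the repro instead returns the sentinel 100 on purpose, so that `assert(res == 100)` confirms the registered derivative was actually invoked rather than Enzyme's own.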