-
Notifications
You must be signed in to change notification settings - Fork 152
Open
Description
When using Enzyme’s forward-mode API (`__enzyme_fwddiff`) to differentiate through `Eigen::Matrix<double, N, N>::inverse()` for fixed-size matrices, compilation fails for N ≥ 4 (reproduced with N = 4): Enzyme reports `cannot handle unknown binary operator` on the integer `xor` inside Eigen's SSE `pxor<Packet2d>`. The same code compiles and runs correctly for N = 2 and N = 3.
Minimal reproducible example
#include <iostream>
#include <enzyme/enzyme>
#include <Eigen/Dense>
#include <vector>
#include <random>
// Activity-annotation globals handed to the __enzyme_* entry points below;
// e.g. enzyme_dup precedes a (primal, shadow) pointer pair in the call.
// NOTE(review): Enzyme presumably identifies these by their exact names, so
// they must not be renamed.
int enzyme_dup, enzyme_dupnoneed, enzyme_out, enzyme_const;
template <typename return_type, typename... T>
return_type __enzyme_fwddiff(void*, T...);
template <typename return_type, typename... T>
return_type __enzyme_autodiff(void*, T...);
// Number of input matrices, and the fixed dimension of each one.
// Per the issue, matrix_size == 4 fails to compile under Enzyme; 2 or 3 work.
constexpr size_t input_size = 2;
constexpr size_t matrix_size = 4;
// Fixed-size square matrix type under test (matrix_size x matrix_size doubles).
using Matrix = Eigen::Matrix<double, matrix_size, matrix_size>;
// Function differentiated by Enzyme: sets `sol` to the sum of all entries of
// inverse(mats[0]). Only mats[0] is read, so derivatives with respect to any
// other element of `mats` are zero.
// NOTE(review): per the issue, this inverse() is what fails to compile for
// matrix_size == 4; the error log shows Enzyme rejecting the integer `xor` in
// Eigen's SSE pxor<Packet2d>, presumably reached via Eigen's vectorized 4x4
// inversion path — confirm before relying on this.
void enzyme_func_matrix(Matrix* mats, double& sol) {
Matrix M0;
M0 = mats[0].inverse();
sol = M0.sum();
}
/// Computes the forward-mode Jacobian of F with respect to every input entry.
///
/// For each matrix index k < input_size and entry (i, j), this seeds a unit
/// tangent at dmats[k](i, j). It then runs Enzyme forward mode on F and
/// stores the resulting directional derivative of the scalar output in
/// J_matrix_forward[k][i][j].
///
/// @tparam F  function called as F(Matrix*, double&), e.g. enzyme_func_matrix.
/// @param inputs  input_size x matrix_size x matrix_size primal values. Taken
///                by const reference to avoid deep-copying the nested vectors.
///                Read with .at(), so an undersized input throws
///                std::out_of_range instead of reading out of bounds.
/// @param J_matrix_forward  output; reset to input_size x matrix_size x
///                          matrix_size and filled with derivatives.
template <auto F>
void compute_J_matrix_forward(
    const std::vector<std::vector<std::vector<double>>>& inputs,
    std::vector<std::vector<std::vector<double>>>& J_matrix_forward
) {
    J_matrix_forward.assign(input_size,
        std::vector<std::vector<double>>(
            matrix_size,
            std::vector<double>(matrix_size, 0.0)
        )
    );

    // Primal matrices (copied from `inputs`) and their tangent (shadow) matrices.
    std::vector<Matrix> mats(input_size), dmats(input_size);
    for (size_t mat_idx = 0; mat_idx < input_size; mat_idx++) {
        for (size_t i = 0; i < matrix_size; i++) {
            for (size_t j = 0; j < matrix_size; j++) {
                mats[mat_idx](i, j) = inputs.at(mat_idx).at(i).at(j);
            }
        }
    }

    // Clear every tangent so exactly one entry is seeded per forward pass.
    auto zero_gradients = [&]() {
        for (auto& M : dmats) {
            M.setZero();
        }
    };

    double result = 0.0;
    double dresult = 0.0;
    for (size_t mat_idx = 0; mat_idx < input_size; mat_idx++) {
        for (size_t i = 0; i < matrix_size; i++) {
            for (size_t j = 0; j < matrix_size; j++) {
                zero_gradients();
                dmats[mat_idx](i, j) = 1.0;
                // enzyme_dup pairs each primal pointer with its shadow pointer;
                // after the call, dresult holds d(result)/d(mats[mat_idx](i, j)).
                __enzyme_fwddiff<void>(reinterpret_cast<void*>(F),
                    enzyme_dup, mats.data(), dmats.data(),
                    enzyme_dup, &result, &dresult);
                J_matrix_forward[mat_idx][i][j] = dresult;
            }
        }
    }
}
/// Fills `inputs` with `num_matrices` square matrices of size dim x dim.
/// Each value is drawn uniformly from [0, 1) using a std::mt19937 seeded from
/// std::random_device, so the results differ from run to run.
///
/// @param dim           row/column count of each generated matrix. Renamed from
///                      `matrix_size` so it no longer shadows the global constant.
/// @param inputs        output; replaced with num_matrices x dim x dim values.
/// @param num_matrices  how many matrices to generate (previously hard-coded).
///                      The default of 2 matches the current input_size. Keep
///                      the two in sync, because compute_J_matrix_forward reads
///                      input_size matrices from this output.
void generate_inputs(size_t dim,
                     std::vector<std::vector<std::vector<double>>>& inputs,
                     size_t num_matrices = 2) {
    inputs.assign(num_matrices,
                  std::vector<std::vector<double>>(dim, std::vector<double>(dim)));

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    for (auto& mat : inputs) {
        for (auto& row : mat) {
            for (auto& val : row) {
                val = dist(gen);
            }
        }
    }
}
// Entry point: generates random input matrices, computes the forward-mode
// Jacobian of enzyme_func_matrix with respect to every input entry, and prints
// each entry on its own line.
int main() {
std::vector<std::vector<std::vector<double>>> J_matrix_forward;
std::vector<std::vector<std::vector<double>>> inputs;
generate_inputs(matrix_size, inputs);
compute_J_matrix_forward<enzyme_func_matrix>(inputs, J_matrix_forward);
// One output line per Jacobian entry: J_matrix_forward[input][i][j] = value
for (size_t input_idx = 0; input_idx < input_size; input_idx++) {
for (size_t i = 0; i < matrix_size; i++) {
for (size_t j = 0; j < matrix_size; j++) {
std::cout << "J_matrix_forward[" << input_idx << "]["
<< i << "][" << j << "] = "
<< J_matrix_forward[input_idx][i][j] << "\n";
}
}
}
return 0;
}Compilation error log
In file included from ./main.cpp:3:
In file included from ../../../build/debug/include/Eigen/Dense:1:
In file included from ../../../build/debug/include/Eigen/Core:205:
../../../build/debug/include/Eigen/src/Core/arch/SSE/PacketMath.h:423:103: error: Enzyme: ; Function Attrs: mustprogress noinline optnone willreturn uwtable
define linkonce_odr dso_local noundef <2 x double> @preprocess__ZN5Eigen8internal4pxorIDv2_dEET_RKS3_S5_(ptr noundef nonnull align 16 dereferenceable(16) %a, ptr noundef nonnull align 16 dereferenceable(16) %b) #31 !dbg !15509 {
entry:
call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511
call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511
%0 = load <2 x double>, ptr %a, align 16, !dbg !15513
%1 = load <2 x double>, ptr %b, align 16, !dbg !15514
%2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !15515
%3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !15515
%xor.i = xor <2 x i64> %2, %3, !dbg !15515
%4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !15515
ret <2 x double> %4, !dbg !15516
}
constantarg[ptr %a] = 0 type: {[-1]:Pointer, [-1,-1]:Float@double} - vals: {}
constantarg[ptr %b] = 0 type: {[-1]:Pointer} - vals: {}
constantinst[ call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
constantinst[ call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
constantinst[ %0 = load <2 x double>, ptr %a, align 16, !dbg !7229] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %1 = load <2 x double>, ptr %b, align 16, !dbg !7230] = 0 val:0 type: {}
constantinst[ %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !7231] = 0 val:0 type: {}
constantinst[ %xor.i = xor <2 x i64> %2, %3, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ ret <2 x double> %4, !dbg !7232] = 1 val:1 type: {}
cannot handle unknown binary operator: %xor.i = xor <2 x i64> %2, %3, !dbg !7231
template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
^
../../../build/debug/include/Eigen/src/Core/arch/SSE/PacketMath.h:423:103: error: Enzyme: ; Function Attrs: mustprogress noinline optnone willreturn uwtable
define linkonce_odr dso_local noundef <2 x double> @preprocess__ZN5Eigen8internal4pxorIDv2_dEET_RKS3_S5_(ptr noundef nonnull align 16 dereferenceable(16) %a, ptr noundef nonnull align 16 dereferenceable(16) %b) #31 !dbg !15509 {
entry:
call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511
call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511
%0 = load <2 x double>, ptr %a, align 16, !dbg !15513
%1 = load <2 x double>, ptr %b, align 16, !dbg !15514
%2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !15515
%3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !15515
%xor.i = xor <2 x i64> %2, %3, !dbg !15515
%4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !15515
ret <2 x double> %4, !dbg !15516
}
constantarg[ptr %a] = 0 type: {[-1]:Pointer, [-1,-1]:Float@double} - vals: {}
constantarg[ptr %b] = 0 type: {[-1]:Pointer, [-1,0]:Integer} - vals: {}
constantinst[ call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
constantinst[ call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
constantinst[ %0 = load <2 x double>, ptr %a, align 16, !dbg !7229] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %1 = load <2 x double>, ptr %b, align 16, !dbg !7230] = 0 val:0 type: {[0]:Integer}
constantinst[ %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[0]:Integer}
constantinst[ %xor.i = xor <2 x i64> %2, %3, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
constantinst[ ret <2 x double> %4, !dbg !7232] = 1 val:1 type: {}
cannot handle unknown binary operator: %xor.i = xor <2 x i64> %2, %3, !dbg !7231
2 errors generated.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels