Skip to content

Compilation error differentiating Eigen::Matrix::inverse() with Enzyme forward-mode when matrix size ≥ 4 #2679

@shadow-orange41

Description

@shadow-orange41

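When using Enzyme’s forward-mode API (__enzyme_fwddiff) to differentiate through Eigen::Matrix<double, N, N>::inverse() for fixed-size matrices, compilation fails when N >= 4 (reproduced with N = 4). The same code compiles and runs correctly for N = 2 and N = 3. As the log below shows, the failure is reported inside Eigen’s SSE pxor<Packet2d> (PacketMath.h:423), where Enzyme stops with “cannot handle unknown binary operator” on the vector xor instruction.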

Minimal reproducible example
#include <iostream>
#include <enzyme/enzyme>
#include <Eigen/Dense>
#include <vector>
#include <random>

// Marker globals passed in the argument list of the Enzyme entry points.
// Only enzyme_dup is actually used in this reproducer (as a tag in front of
// each value/derivative pointer pair); the other three are declared but unused.
// NOTE(review): the exact meaning of each marker is defined by Enzyme's
// calling convention — confirm against the Enzyme documentation.
int enzyme_dup;
int enzyme_dupnoneed;
int enzyme_out;
int enzyme_const;

// Declaration only (no body in this file) of the forward-mode entry point.
// The first argument is the function to differentiate, cast to void*.
template <typename return_type, typename... T>
return_type __enzyme_fwddiff(void*, T...);

// Declaration only of a second Enzyme entry point; it is not called anywhere
// in this reproducer.
template <typename return_type, typename... T>
return_type __enzyme_autodiff(void*, T...);

// Number of matrices handed to the differentiated function.
const size_t input_size = 2;
// Side length of each square matrix; 4 triggers the reported compilation error.
const size_t matrix_size = 4; // changing this to 2 or 3 compiles and runs fine

// Fixed-size square matrix of doubles used throughout the reproducer.
using Matrix = Eigen::Matrix<double, matrix_size, matrix_size>;

// Function under differentiation: inverts the first matrix in `mats` and
// writes the sum of the inverse's entries to `sol`.
//   mats - pointer to an array of matrices; only mats[0] is read here.
//   sol  - output; overwritten with the sum of the entries of the inverse.
void enzyme_func_matrix(Matrix* mats, double& sol) {
    Matrix M0;
    // Store the inverse in a concrete matrix before reducing it.
    M0 = mats[0].inverse();
    sol = M0.sum();
}

// Builds the Jacobian of F's scalar output with respect to every entry of
// every input matrix by calling __enzyme_fwddiff once per entry with a
// one-hot derivative seed.
//
//   F                - function to differentiate; called through Enzyme with a
//                      Matrix array and a double output, each duplicated with
//                      its derivative counterpart.
//   inputs           - nested vectors indexed [matrix][row][col]; must hold at
//                      least input_size matrices of matrix_size x matrix_size
//                      values (not checked here).
//   J_matrix_forward - output; reset to input_size x matrix_size x matrix_size
//                      and filled so that [m][i][j] holds the value Enzyme
//                      wrote to the derivative of the output when entry (i, j)
//                      of matrix m was seeded with 1.
//
// `inputs` is taken by const reference: it is only read, and copying a triply
// nested vector on every call is needless work.
template <auto F>
void compute_J_matrix_forward(
    const std::vector<std::vector<std::vector<double>>>& inputs,
    std::vector<std::vector<std::vector<double>>>& J_matrix_forward
) {
    J_matrix_forward.assign(input_size,
        std::vector<std::vector<double>>(
            matrix_size,
            std::vector<double>(matrix_size, 0.0)
        )
    );

    // Primal matrices and their derivative (seed) counterparts.
    std::vector<Matrix> mats(input_size), dmats(input_size);

    // Copy the nested-vector inputs into the fixed-size Eigen matrices.
    for (size_t mat_idx = 0; mat_idx < input_size; mat_idx++) {
        for (size_t i = 0; i < matrix_size; i++) {
            for (size_t j = 0; j < matrix_size; j++) {
                mats[mat_idx](i, j) = inputs[mat_idx][i][j];
            }
        }
    }

    // Clears every seed matrix so exactly one entry is active per call.
    auto zero_gradients = [&dmats]() {
        for (auto& seed : dmats) {
            seed.setZero();
        }
    };

    // Initialised so no indeterminate value is ever read, even if the
    // differentiated call leaves one of them untouched.
    double result = 0.0;
    double dresult = 0.0;
    for (size_t mat_idx = 0; mat_idx < input_size; mat_idx++) {
        for (size_t i = 0; i < matrix_size; i++) {
            for (size_t j = 0; j < matrix_size; j++) {
                // One-hot seed: derivative with respect to mats[mat_idx](i, j).
                zero_gradients();
                dmats[mat_idx](i, j) = 1.0;

                // Each enzyme_dup tag is followed by a value pointer and the
                // pointer to its derivative counterpart.
                __enzyme_fwddiff<void>(reinterpret_cast<void*>(F),
                    enzyme_dup, mats.data(),  dmats.data(),
                    enzyme_dup, &result,     &dresult);

                J_matrix_forward[mat_idx][i][j] = dresult;
            }
        }
    }
}

// Fills `inputs` with two matrix_size x matrix_size grids of random doubles
// drawn from a uniform distribution constructed over (0.0, 1.0).
//   matrix_size - side length of each generated grid.
//   inputs      - output; any previous contents are replaced.
// The generator is seeded from std::random_device, so values differ per run.
void generate_inputs(size_t matrix_size, std::vector<std::vector<std::vector<double>>>& inputs) {
    using Row = std::vector<double>;
    using Grid = std::vector<Row>;

    // Shape the output first: two grids, each matrix_size rows by matrix_size columns.
    const Grid blank(matrix_size, Row(matrix_size));
    inputs.assign(2, blank);

    std::random_device seed_source;
    std::mt19937 engine(seed_source());
    std::uniform_real_distribution<double> unit(0.0, 1.0);

    // Visit entries grid by grid, row by row, drawing one sample per entry.
    for (size_t g = 0; g < inputs.size(); ++g) {
        Grid& grid = inputs[g];
        for (size_t r = 0; r < grid.size(); ++r) {
            Row& row = grid[r];
            for (size_t c = 0; c < row.size(); ++c) {
                row[c] = unit(engine);
            }
        }
    }
}

int main() {
    std::vector<std::vector<std::vector<double>>> J_matrix_forward;
    std::vector<std::vector<std::vector<double>>> inputs;

    generate_inputs(matrix_size, inputs);
    compute_J_matrix_forward<enzyme_func_matrix>(inputs, J_matrix_forward);

    for (size_t input_idx = 0; input_idx < input_size; input_idx++) {
        for (size_t i = 0; i < matrix_size; i++) {
            for (size_t j = 0; j < matrix_size; j++) {
                std::cout << "J_matrix_forward[" << input_idx << "]["
                          << i << "][" << j << "] = "
                          << J_matrix_forward[input_idx][i][j] << "\n";
            }
        }
    }
    return 0;
}
Compilation error log
In file included from ./main.cpp:3:
In file included from ../../../build/debug/include/Eigen/Dense:1:
In file included from ../../../build/debug/include/Eigen/Core:205:
../../../build/debug/include/Eigen/src/Core/arch/SSE/PacketMath.h:423:103: error: Enzyme: ; Function Attrs: mustprogress noinline optnone willreturn uwtable
define linkonce_odr dso_local noundef <2 x double> @preprocess__ZN5Eigen8internal4pxorIDv2_dEET_RKS3_S5_(ptr noundef nonnull align 16 dereferenceable(16) %a, ptr noundef nonnull align 16 dereferenceable(16) %b) #31 !dbg !15509 {
entry:
  call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511
  call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511
  %0 = load <2 x double>, ptr %a, align 16, !dbg !15513
  %1 = load <2 x double>, ptr %b, align 16, !dbg !15514
  %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !15515
  %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !15515
  %xor.i = xor <2 x i64> %2, %3, !dbg !15515
  %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !15515
  ret <2 x double> %4, !dbg !15516
}

 constantarg[ptr %a] = 0 type: {[-1]:Pointer, [-1,-1]:Float@double} - vals: {}
 constantarg[ptr %b] = 0 type: {[-1]:Pointer} - vals: {}
 constantinst[  call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
 constantinst[  call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
 constantinst[  %0 = load <2 x double>, ptr %a, align 16, !dbg !7229] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %1 = load <2 x double>, ptr %b, align 16, !dbg !7230] = 0 val:0 type: {}
 constantinst[  %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !7231] = 0 val:0 type: {}
 constantinst[  %xor.i = xor <2 x i64> %2, %3, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  ret <2 x double> %4, !dbg !7232] = 1 val:1 type: {}
cannot handle unknown binary operator:   %xor.i = xor <2 x i64> %2, %3, !dbg !7231

template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
                                                                                                      ^
../../../build/debug/include/Eigen/src/Core/arch/SSE/PacketMath.h:423:103: error: Enzyme: ; Function Attrs: mustprogress noinline optnone willreturn uwtable
define linkonce_odr dso_local noundef <2 x double> @preprocess__ZN5Eigen8internal4pxorIDv2_dEET_RKS3_S5_(ptr noundef nonnull align 16 dereferenceable(16) %a, ptr noundef nonnull align 16 dereferenceable(16) %b) #31 !dbg !15509 {
entry:
  call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511
  call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511
  %0 = load <2 x double>, ptr %a, align 16, !dbg !15513
  %1 = load <2 x double>, ptr %b, align 16, !dbg !15514
  %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !15515
  %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !15515
  %xor.i = xor <2 x i64> %2, %3, !dbg !15515
  %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !15515
  ret <2 x double> %4, !dbg !15516
}

 constantarg[ptr %a] = 0 type: {[-1]:Pointer, [-1,-1]:Float@double} - vals: {}
 constantarg[ptr %b] = 0 type: {[-1]:Pointer, [-1,0]:Integer} - vals: {}
 constantinst[  call void @llvm.dbg.value(metadata ptr %a, metadata !15510, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
 constantinst[  call void @llvm.dbg.value(metadata ptr %b, metadata !15512, metadata !DIExpression()), !dbg !15511] = 1 val:1 type: {}
 constantinst[  %0 = load <2 x double>, ptr %a, align 16, !dbg !7229] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %1 = load <2 x double>, ptr %b, align 16, !dbg !7230] = 0 val:0 type: {[0]:Integer}
 constantinst[  %2 = bitcast <2 x double> %0 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %3 = bitcast <2 x double> %1 to <2 x i64>, !dbg !7231] = 0 val:0 type: {[0]:Integer}
 constantinst[  %xor.i = xor <2 x i64> %2, %3, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  %4 = bitcast <2 x i64> %xor.i to <2 x double>, !dbg !7231] = 0 val:0 type: {[-1]:Float@double}
 constantinst[  ret <2 x double> %4, !dbg !7232] = 1 val:1 type: {}
cannot handle unknown binary operator:   %xor.i = xor <2 x i64> %2, %3, !dbg !7231

2 errors generated.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions