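// Page residue from the GitHub issue this file was copied from ("Notifications",
// "You must be signed in to change notification settings", "Fork 135"); the
// issue "Description" is the source code below.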
// amgcl_solve.cpp
// pybind11 (and thus Python.h) first, as Python.h may adjust macros that
// standard headers depend on.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>

#include <amgcl/backend/builtin.hpp>
//#include <amgcl/adapter/block_matrix.hpp>
#include <amgcl/adapter/crs_tuple.hpp>
#include <amgcl/profiler.hpp>
#include <amgcl/make_solver.hpp>
#include <amgcl/amg.hpp>
#include <amgcl/coarsening/smoothed_aggregation.hpp>
#include <amgcl/relaxation/spai0.hpp>
#include <amgcl/solver/bicgstab.hpp>
namespace py = pybind11;
using namespace std;
//
std::vector solve(const vector& row, const vector& col, const vector& val, const vector& fb) {
//std::vector output;
long long unsigned int b_length = fb.size();
std::vector x(b_length, 0.0);
//long long unsigned int row_size = row.size() - 1;
//
typedef amgcl::backend::builtin<double> Backend;
typedef amgcl::make_solver<
amgcl::amg<
Backend,
amgcl::coarsening::smoothed_aggregation,
amgcl::relaxation::spai0
>,
amgcl::solver::bicgstab<Backend>
> Solver;
std::cout << "row :"<< row.size()<<"n" <<std::endl ;
std::cout << "col :"<< col.size()<<"n" <<std::endl ;
std::cout << "val :"<< val.size()<<"n" <<std::endl ;
std::cout << "n :"<< b_length<<"n" <<std::endl ;
Solver solver(std::tie(b_length, row, col, val));
//
std::tie(std::ignore, std::ignore) = solver(fb, x);
return x;
}
// Python bindings: expose solve() as amgcl_solve.solve(row, col, val, fb).
PYBIND11_MODULE(amgcl_solve, m) {
    // c_style | forcecast: pybind11 converts each incoming NumPy array (e.g. an
    // int64 index array, or a non-contiguous slice) into a C-contiguous array of
    // exactly this dtype, so the contiguous reads through data() below are valid.
    typedef py::array_t<int, py::array::c_style | py::array::forcecast> IndexArray;
    typedef py::array_t<double, py::array::c_style | py::array::forcecast> ValueArray;

    m.def(
        "solve",
        [](IndexArray row, IndexArray col, ValueArray val, ValueArray fb) {
            // Copy the NumPy buffers into std::vector for the C++ solver.
            const std::vector<int> row_vector(row.data(), row.data() + row.size());
            const std::vector<int> col_vector(col.data(), col.data() + col.size());
            const std::vector<double> val_vector(val.data(), val.data() + val.size());
            const std::vector<double> fb_vector(fb.data(), fb.data() + fb.size());

            const std::vector<double> x = solve(row_vector, col_vector, val_vector, fb_vector);

            // Passing a data pointer without a base handle makes pybind11 allocate
            // a new NumPy array that owns a copy of x.
            return py::array_t<double>(static_cast<py::ssize_t>(x.size()), x.data());
        },
        py::arg("row"), py::arg("col"), py::arg("val"), py::arg("fb"),
        "Solve the sparse system A x = fb, with A given in CSR form (row pointers, "
        "column indices, values). Returns x as a float64 NumPy array.");
}