diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index b27e6497..3995dda9 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -292,6 +292,11 @@ class SimpleArrayMixinSearch } return max_index; } + + SimpleArray argwhere() const; + + SimpleArray argwhere(std::function const & condition) const; + A where(std::function const & condition, A const & other) const; }; /* end class SimpleArrayMixinSearch */ } /* end namespace detail */ @@ -1032,6 +1037,69 @@ A detail::SimpleArrayMixinSort::take_along_axis_simd(SimpleArray const return ret; } +template +SimpleArray detail::SimpleArrayMixinSearch::argwhere() const +{ + auto default_condition = [](value_type const & x) + { + return x != value_type(); + }; + return this->argwhere(default_condition); +} + +template +SimpleArray detail::SimpleArrayMixinSearch::argwhere(std::function const & condition) const +{ + auto athis = static_cast(this); + uint64_t const array_size = athis->size(); + uint64_t const array_dim = athis->ndim(); + std::vector indices; + for (uint64_t i = 0; i < array_size; ++i) + { + if (condition(athis->data(i))) + { + indices.push_back(i); + } + } + + SimpleArray coordinates(std::vector{indices.size(), array_dim}); + auto coord = coordinates.begin(); + + std::vector product_of_dims(array_dim, 1); + for (size_t i = 1; i < array_dim; ++i) product_of_dims[i] = product_of_dims[i - 1] * athis->shape(i); + for (auto const & index : indices) + { + uint64_t remaining_index = index; + for (int i = array_dim - 1; i >= 0; --i) + { + *coord = remaining_index / product_of_dims[i]; + remaining_index = remaining_index % product_of_dims[i]; + ++coord; + } + } + return coordinates; +} + +template +A detail::SimpleArrayMixinSearch::where(std::function const & condition, A const & other) const +{ + auto athis = static_cast(this); + uint64_t const array_size = athis->size(); + A ret(athis->shape()); + for (uint64_t i = 0; i < array_size; ++i) + { + if (condition(athis->data(i))) + { + ret.data(i) = athis->data(i); + } + else + { + ret.data(i) = other.data(i); + } + } + return ret; +} + template using is_simple_array = std::is_same< std::remove_reference_t, diff --git a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp index b8e643ab..cd870e7b 100644 --- a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp +++ b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp @@ -27,6 +27,7 @@ * POSSIBILITY OF SUCH DAMAGE. */ #include +#include #include #include diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index f0477f1e..33684ca4 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -313,9 +313,23 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray (*this) .def("argmin", &wrapped_type::argmin) .def("argmax", &wrapped_type::argmax) + .def( + "argwhere", + [](wrapped_type const & self, std::function const & condition) + { + if (!condition) + { + return py::cast(self.argwhere()); + } + return py::cast(self.argwhere(condition)); + }, + py::arg("condition") = py::none()) + .def( + "where", + [](wrapped_type const & self, std::function const & condition, wrapped_type const & other) + { return py::cast(self.where(condition, other)); }) // ; - return *this; } }; /* end class WrapSimpleArray */ diff --git a/tests/test_buffer.py b/tests/test_buffer.py index b9738d3f..917ddd9b 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1021,6 +1021,85 @@ def test_argminmax(self): self.assertEqual(narr.argmin(), sarr.argmin()) self.assertEqual(narr.argmax(), sarr.argmax()) + def test_argwhere(self): + # test 1-D data + data = [1, 3, 5, 7, 9] + narr = np.array(data, dtype='uint64') + sarr = modmesh.SimpleArrayUint64(array=narr) + + ret_np = np.argwhere(narr > 5) + ret_sa = sarr.argwhere(lambda x: x > 5) + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(ret_sa.shape[0]): + for j in range(ret_sa.shape[1]): + self.assertEqual(ret_np[i, j], ret_sa[i, j]) + + # test N-D data + data = [[-1.3, -4.8, 1.5, 0.3, 7.1], [2.5, 4.8, -0.1, 9.4, 7.6]] + narr = np.array(data, dtype='float64') + sarr = modmesh.SimpleArrayFloat64(array=narr) + + ret_np = np.argwhere(narr <= 2.8) + ret_sa = sarr.argwhere(lambda x: x <= 2.8) + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(ret_sa.shape[0]): + for j in range(ret_sa.shape[1]): + self.assertEqual(ret_np[i, j], ret_sa[i, j]) + + # test N-D data + data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[1, 3], [5, 7]]] + narr = np.array(data, dtype='int64') + sarr = modmesh.SimpleArrayInt64(array=narr) + + ret_np = np.argwhere(narr == 5) + ret_sa = sarr.argwhere(lambda x: x == 5) + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(ret_sa.shape[0]): + for j in range(ret_sa.shape[1]): + self.assertEqual(ret_np[i, j], ret_sa[i, j]) + + # default case: non-zero + ret_np = np.argwhere(narr) + ret_sa = sarr.argwhere() + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(ret_sa.shape[0]): + for j in range(ret_sa.shape[1]): + self.assertEqual(ret_np[i, j], ret_sa[i, j]) + + def test_where(self): + # test 1-D data + data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + narr_1 = np.array(data, dtype='uint64') + narr_2 = narr_1 * 10 + sarr_1 = modmesh.SimpleArrayUint64(array=narr_1) + sarr_2 = modmesh.SimpleArrayUint64(array=narr_2) + + ret_np = np.where(narr_1 < 5, narr_1, narr_2) + ret_sa = sarr_1.where(lambda x: x < 5, sarr_2) + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(len(ret_np)): + self.assertEqual(ret_np[i], ret_sa[i]) + + # test N-D data + data = [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10], [1, 10, 1, 10, 1]] + narr_1 = np.array(data, dtype='float64') + narr_2 = narr_1 * 10 + sarr_1 = modmesh.SimpleArrayFloat64(array=narr_1) + sarr_2 = modmesh.SimpleArrayFloat64(array=narr_2) + + ret_np = np.where(narr_1 == 1, narr_1, narr_2) + ret_sa = sarr_1.where(lambda x: x == 1, sarr_2) + + self.assertEqual(ret_np.shape, ret_sa.shape) + for i in range(ret_np.shape[0]): + for j in range(ret_np.shape[1]): + self.assertEqual(ret_np[i, j], ret_sa[i, j]) + class SimpleArrayPlexTC(unittest.TestCase):