Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ class SimpleArrayMixinSearch
}
return max_index;
}

SimpleArray<uint64_t> argwhere() const;

SimpleArray<uint64_t> argwhere(std::function<bool(value_type const &)> const & condition) const;
A where(std::function<bool(value_type const &)> const & condition, A const & other) const;
}; /* end class SimpleArrayMixinSearch */

} /* end namespace detail */
Expand Down Expand Up @@ -1032,6 +1037,69 @@ A detail::SimpleArrayMixinSort<A, T>::take_along_axis_simd(SimpleArray<I> const
return ret;
}

template <typename A, typename T>
SimpleArray<uint64_t> detail::SimpleArrayMixinSearch<A, T>::argwhere() const
{
auto default_condition = [](value_type const & x)
{
return x != value_type();
};
return this->argwhere(default_condition);
}

template <typename A, typename T>
SimpleArray<uint64_t> detail::SimpleArrayMixinSearch<A, T>::argwhere(std::function<bool(value_type const &)> const & condition) const
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Use an std::vector to store the indices of elements that satisfy the condition.
  • Creating the output coordinates and filling the dimensions's information.

{
auto athis = static_cast<A const *>(this);
uint64_t const array_size = athis->size();
uint64_t const array_dim = athis->ndim();
std::vector<uint64_t> indices;
for (uint64_t i = 0; i < array_size; ++i)
{
if (condition(athis->data(i)))
{
indices.push_back(i);
}
}

SimpleArray<uint64_t> coordinates(std::vector<size_t>{indices.size(), array_dim});
auto coord = coordinates.begin();

std::vector<uint64_t> product_of_dims(array_dim, 1);
for (size_t i = 1; i < array_dim; ++i) product_of_dims[i] = product_of_dims[i - 1] * athis->shape(i);
for (auto const & index : indices)
{
uint64_t remaining_index = index;
for (int i = array_dim - 1; i >= 0; --i)
{
*coord = remaining_index / product_of_dims[i];
remaining_index = remaining_index % product_of_dims[i];
++coord;
}
}
return coordinates;
}

template <typename A, typename T>
A detail::SimpleArrayMixinSearch<A, T>::where(std::function<bool(value_type const &)> const & condition, A const & other) const
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Should output the same shape as input.
  • If the condition is met, fill with the original value; otherwise, use the alternative.

{
auto athis = static_cast<A const *>(this);
uint64_t const array_size = athis->size();
A ret(athis->shape());
for (uint64_t i = 0; i < array_size; ++i)
{
if (condition(athis->data(i)))
{
ret.data(i) = athis->data(i);
}
else
{
ret.data(i) = other.data(i);
}
}
return ret;
}

template <typename S>
using is_simple_array = std::is_same<
std::remove_reference_t<S>,
Expand Down
1 change: 1 addition & 0 deletions cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>

#include <modmesh/buffer/buffer.hpp>
#include <modmesh/math/math.hpp>
Expand Down
16 changes: 15 additions & 1 deletion cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,23 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray
(*this)
.def("argmin", &wrapped_type::argmin)
.def("argmax", &wrapped_type::argmax)
.def(
"argwhere",
[](wrapped_type const & self, std::function<bool(value_type const &)> const & condition)
{
if (!condition)
{
return py::cast(self.argwhere());
}
return py::cast(self.argwhere(condition));
},
py::arg("condition") = py::none())
.def(
"where",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numpy's where() says that when only condition is provided, nonzero() should be preferred. So I didn't give SimpleArray::where default operation.

[](wrapped_type const & self, std::function<bool(value_type const &)> const & condition, wrapped_type const & other)
{ return py::cast(self.where(condition, other)); })
//
;

return *this;
}
}; /* end class WrapSimpleArray */
Expand Down
79 changes: 79 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,85 @@ def test_argminmax(self):
self.assertEqual(narr.argmin(), sarr.argmin())
self.assertEqual(narr.argmax(), sarr.argmax())

def test_argwhere(self):
# test 1-D data
data = [1, 3, 5, 7, 9]
narr = np.array(data, dtype='uint64')
sarr = modmesh.SimpleArrayUint64(array=narr)

ret_np = np.argwhere(narr > 5)
ret_sa = sarr.argwhere(lambda x: x > 5)

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(ret_sa.shape[0]):
for j in range(ret_sa.shape[1]):
self.assertEqual(ret_np[i, j], ret_sa[i, j])

# test N-D data
data = [[-1.3, -4.8, 1.5, 0.3, 7.1], [2.5, 4.8, -0.1, 9.4, 7.6]]
narr = np.array(data, dtype='float64')
sarr = modmesh.SimpleArrayFloat64(array=narr)

ret_np = np.argwhere(narr <= 2.8)
ret_sa = sarr.argwhere(lambda x: x <= 2.8)

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(ret_sa.shape[0]):
for j in range(ret_sa.shape[1]):
self.assertEqual(ret_np[i, j], ret_sa[i, j])

# test N-D data
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[1, 3], [5, 7]]]
narr = np.array(data, dtype='int64')
sarr = modmesh.SimpleArrayInt64(array=narr)

ret_np = np.argwhere(narr == 5)
ret_sa = sarr.argwhere(lambda x: x == 5)

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(ret_sa.shape[0]):
for j in range(ret_sa.shape[1]):
self.assertEqual(ret_np[i, j], ret_sa[i, j])

# default case: non-zero
ret_np = np.argwhere(narr)
ret_sa = sarr.argwhere()

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(ret_sa.shape[0]):
for j in range(ret_sa.shape[1]):
self.assertEqual(ret_np[i, j], ret_sa[i, j])

def test_where(self):
# test 1-D data
data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
narr_1 = np.array(data, dtype='uint64')
narr_2 = narr_1 * 10
sarr_1 = modmesh.SimpleArrayUint64(array=narr_1)
sarr_2 = modmesh.SimpleArrayUint64(array=narr_2)

ret_np = np.where(narr_1 < 5, narr_1, narr_2)
ret_sa = sarr_1.where(lambda x: x < 5, sarr_2)

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(len(ret_np)):
self.assertEqual(ret_np[i], ret_sa[i])

# test N-D data
data = [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10], [1, 10, 1, 10, 1]]
narr_1 = np.array(data, dtype='float64')
narr_2 = narr_1 * 10
sarr_1 = modmesh.SimpleArrayFloat64(array=narr_1)
sarr_2 = modmesh.SimpleArrayFloat64(array=narr_2)

ret_np = np.where(narr_1 == 1, narr_1, narr_2)
ret_sa = sarr_1.where(lambda x: x == 1, sarr_2)

self.assertEqual(ret_np.shape, ret_sa.shape)
for i in range(ret_np.shape[0]):
for j in range(ret_np.shape[1]):
self.assertEqual(ret_np[i, j], ret_sa[i, j])


class SimpleArrayPlexTC(unittest.TestCase):

Expand Down
Loading