diff --git a/CHANGELOG.md b/CHANGELOG.md index eb31d3cd325..9ef9ba86f2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `timeout-minutes` property to GitHub jobs [#2526](https://github.com/IntelPython/dpnp/pull/2526) * Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521) * Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534) +* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550) ### Changed diff --git a/dpnp/CMakeLists.txt b/dpnp/CMakeLists.txt index 6be90d849dc..80b2552ea58 100644 --- a/dpnp/CMakeLists.txt +++ b/dpnp/CMakeLists.txt @@ -58,6 +58,7 @@ endfunction() add_subdirectory(backend) add_subdirectory(backend/extensions/blas) add_subdirectory(backend/extensions/fft) +add_subdirectory(backend/extensions/functional) add_subdirectory(backend/extensions/indexing) add_subdirectory(backend/extensions/lapack) add_subdirectory(backend/extensions/statistics) diff --git a/dpnp/backend/extensions/functional/CMakeLists.txt b/dpnp/backend/extensions/functional/CMakeLists.txt new file mode 100644 index 00000000000..f248bb95f09 --- /dev/null +++ b/dpnp/backend/extensions/functional/CMakeLists.txt @@ -0,0 +1,90 @@ +# ***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + + +set(python_module_name _functional_impl) +set(_module_src + ${CMAKE_CURRENT_SOURCE_DIR}/piecewise.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/functional_py.cpp +) + +pybind11_add_module(${python_module_name} MODULE ${_module_src}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) + +if(_dpnp_sycl_targets) + # make fat binary + target_compile_options( + ${python_module_name} + PRIVATE + ${_dpnp_sycl_target_compile_options} + ) + target_link_options( + ${python_module_name} + PRIVATE + ${_dpnp_sycl_target_link_options} + ) +endif() + +if (WIN32) + if (${CMAKE_VERSION} VERSION_LESS "3.27") + # this is a work-around for target_link_options inserting option after -link option, cause + # linker to ignore it. + set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel") + endif() +endif() + +set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON) + +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) + +target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR}) +target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) + +if (WIN32) + target_compile_options(${python_module_name} PRIVATE + /clang:-fno-approx-func + /clang:-fno-finite-math-only + ) +else() + target_compile_options(${python_module_name} PRIVATE + -fno-approx-func + -fno-finite-math-only + ) +endif() + +target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel) + +if (DPNP_GENERATE_COVERAGE) + target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping) +endif() + +if (DPNP_WITH_REDIST) + set_target_properties(${python_module_name} PROPERTIES INSTALL_RPATH "$ORIGIN/../../../../../../") +endif() + +install(TARGETS ${python_module_name} + DESTINATION "dpnp/backend/extensions/functional" +) diff --git a/dpnp/backend/extensions/functional/functional_py.cpp b/dpnp/backend/extensions/functional/functional_py.cpp new file mode 100644 index 00000000000..0ba9f0b2a94 --- /dev/null +++ b/dpnp/backend/extensions/functional/functional_py.cpp @@ -0,0 +1,48 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +// This file defines functions of dpnp.backend._functional_impl extensions +// +//***************************************************************************** + +#include +#include + +#include "piecewise.hpp" + +namespace functional_ns = dpnp::extensions::functional; +namespace py = pybind11; + +PYBIND11_MODULE(_functional_impl, m) +{ + { + functional_ns::init_piecewise_dispatch_vectors(); + + m.def("_piecewise", functional_ns::py_piecewise, + "Call piecewise kernel", py::arg("sycl_queue"), py::arg("value"), + py::arg("condition"), py::arg("result"), + py::arg("depends") = py::list()); + } +} diff --git a/dpnp/backend/extensions/functional/piecewise.cpp b/dpnp/backend/extensions/functional/piecewise.cpp new file mode 100644 index 00000000000..81c9796a2ce --- /dev/null +++ b/dpnp/backend/extensions/functional/piecewise.cpp @@ -0,0 +1,215 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include "piecewise.hpp" + +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include +#include + +namespace dpnp::extensions::functional +{ +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; + +typedef sycl::event (*piecewise_fn_ptr_t)(sycl::queue &, + const py::object &, + const std::size_t, + const char *, + char *, + const std::vector &); + +static piecewise_fn_ptr_t piecewise_dispatch_vector[dpctl_td_ns::num_types]; + +template +class PiecewiseFunctor +{ +private: + const T val; + const bool *cond = nullptr; + T *res = nullptr; + +public: + PiecewiseFunctor(const T val, const bool *cond, T *res) + : val(val), cond(cond), res(res) + { + } + + void operator()(sycl::id<1> id) const + { + const auto i = id.get(0); + if (cond[i]) { + res[i] = val; + } + } +}; + +template +sycl::event piecewise_impl(sycl::queue &exec_q, + const py::object &value, + const std::size_t nelems, + const char *condition, + char *result, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(exec_q); + + py::object type_obj = py::type::of(value); + std::string type_name = py::str(type_obj.attr("__name__")); + + T *res = reinterpret_cast(result); + const bool *cond = reinterpret_cast(condition); + T val = py::cast(value); + + sycl::event piecewise_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using PiecewiseKernel = PiecewiseFunctor; + cgh.parallel_for(sycl::range<1>(nelems), + PiecewiseKernel(val, cond, res)); + }); + + return piecewise_ev; +} + +/** + * @brief A factory to define pairs of supported types for which + * piecewise function is available. + * + * @tparam T Type of input vector `a` and of result vector `y`. + */ +template +struct PiecewiseOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + +template +struct PiecewiseFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename PiecewiseOutputType::value_type, void>) { + return nullptr; + } + else { + return piecewise_impl; + } + } +}; + +std::pair + py_piecewise(sycl::queue &exec_q, + const py::object &value, + const dpctl::tensor::usm_ndarray &condition, + const dpctl::tensor::usm_ndarray &result, + const std::vector &depends) +{ + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result); + + const int res_nd = result.get_ndim(); + const int cond_nd = condition.get_ndim(); + if (res_nd != cond_nd) { + throw py::value_error( + "Condition and result arrays must have the same dimension."); + } + + if (!dpctl::utils::queues_are_compatible( + exec_q, {condition.get_queue(), result.get_queue()})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queue."); + } + + const bool is_result_c_contig = result.is_c_contiguous(); + if (!is_result_c_contig) { + throw py::value_error("The result array is not c-contiguous."); + } + + const py::ssize_t *res_shape = result.get_shape_raw(); + const py::ssize_t *cond_shape = condition.get_shape_raw(); + + const bool shapes_equal = + std::equal(res_shape, res_shape + res_nd, cond_shape); + if (!shapes_equal) { + throw py::value_error( + "Condition and result arrays must have the same shape."); + } + + const std::size_t nelems = result.get_size(); + if (nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); + } + + const int result_typenum = result.get_typenum(); + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int result_type_id = array_types.typenum_to_lookup_id(result_typenum); + auto piecewise_fn = piecewise_dispatch_vector[result_type_id]; + + if (piecewise_fn == nullptr) { + throw std::runtime_error("Type of given array is not supported"); + } + + const char *condition_typeless_ptr = condition.get_data(); + char *result_typeless_ptr = result.get_data(); + + sycl::event piecewise_ev = + piecewise_fn(exec_q, value, nelems, condition_typeless_ptr, + result_typeless_ptr, depends); + sycl::event args_ev = + dpctl::utils::keep_args_alive(exec_q, {result}, {piecewise_ev}); + + return std::make_pair(args_ev, piecewise_ev); +} + +void init_piecewise_dispatch_vectors(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(piecewise_dispatch_vector); + + return; +} + +} // namespace dpnp::extensions::functional diff --git a/dpnp/backend/extensions/functional/piecewise.hpp b/dpnp/backend/extensions/functional/piecewise.hpp new file mode 100644 index 00000000000..efe09ca85d5 --- /dev/null +++ b/dpnp/backend/extensions/functional/piecewise.hpp @@ -0,0 +1,42 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +namespace dpnp::extensions::functional +{ +extern std::pair + py_piecewise(sycl::queue &exec_q, + const py::object &value, + const dpctl::tensor::usm_ndarray &condition, + const dpctl::tensor::usm_ndarray &result, + const std::vector &depends); + +extern void init_piecewise_dispatch_vectors(void); + +} // namespace dpnp::extensions::functional diff --git a/dpnp/dpnp_iface_functional.py b/dpnp/dpnp_iface_functional.py index e6fd9883d81..074d95100f3 100644 --- a/dpnp/dpnp_iface_functional.py +++ b/dpnp/dpnp_iface_functional.py @@ -36,15 +36,22 @@ """ +# pylint: disable=no-name-in-module +# pylint: disable=protected-access +import dpctl.utils as dpu from dpctl.tensor._numpy_helper import ( normalize_axis_index, normalize_axis_tuple, ) import dpnp +import dpnp.backend.extensions.functional._functional_impl as fi -__all__ = ["apply_along_axis", "apply_over_axes"] +# pylint: disable=no-name-in-module +from dpnp.dpnp_utils import get_usm_allocations + +__all__ = ["apply_along_axis", "apply_over_axes", "piecewise"] def apply_along_axis(func1d, axis, arr, *args, **kwargs): @@ -266,3 +273,149 @@ def apply_over_axes(func, a, axes): ) a = res return res + + +def piecewise(x, condlist, funclist): + """ + Evaluate a piecewise-defined function. + + Given a set of conditions and corresponding functions, evaluate each + function on the input data wherever its condition is true. + + For full documentation refer to :obj:`numpy.piecewise`. + + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + The input domain. + condlist : {list of array-like boolean, bool scalars} + Each boolean array/scalar corresponds to a function in `funclist`. + Wherever `condlist[i]` is ``True``, `funclist[i](x)` is used as the + output value. + + Each boolean array in `condlist` selects a piece of `x`, and should + therefore be of the same shape as `x`. + + The length of `condlist` must correspond to that of `funclist`. + If one extra function is given, i.e. if + ``len(funclist) == len(condlist) + 1``, then that extra function + is the default value, used wherever all conditions are ``False``. + funclist : {array-like of scalars} + A constant value is returned wherever corresponding condition of `x` + is ``True``. + + Returns + ------- + out : dpnp.ndarray + The output is the same shape and type as `x` and is found by + calling the functions in `funclist` on the appropriate portions of `x`, + as defined by the boolean arrays in `condlist`. Portions not covered + by any condition have a default value of ``0``. + + Limitations + ----------- + Parameters `args` and `kw` are not supported and `funclist` cannot include a + callable functions. + + See Also + -------- + :obj:`dpnp.choose` : Construct an array from an index array and a set of + arrays to choose from. + :obj:`dpnp.select` : Return an array drawn from elements in `choicelist`, + depending on conditions. + :obj:`dpnp.where` : Return elements from one of two arrays depending + on condition. + + Examples + -------- + >>> import dpnp as np + + Define the signum function, which is -1 for ``x < 0`` and +1 for ``x >= 0``. + + >>> x = np.linspace(-2.5, 2.5, 6) + >>> np.piecewise(x, [x < 0, x >= 0], [-1, 1]) + array([-1., -1., -1., 1., 1., 1.]) + + """ + dpnp.check_supported_arrays_type(x) + if isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]: + condlist = [condlist] + elif dpnp.isscalar(condlist) or ( + dpnp.isscalar(condlist[0]) and x.ndim != 0 + ): + # convert scalar to a list of one array + # convert list of scalars to a list of one array + condlist = [ + dpnp.full( + x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue + ) + ] + elif not isinstance(condlist[0], (dpnp.ndarray)): + # convert list of lists to list of arrays + # convert list of scalars to a list of 0d arrays (for 0d input) + tmp = [] + for _, cond in enumerate(condlist): + tmp.append( + dpnp.array(cond, usm_type=x.usm_type, sycl_queue=x.sycl_queue) + ) + condlist = tmp + + dpnp.check_supported_arrays_type(*condlist) + if dpnp.is_supported_array_type(funclist): + usm_type, exec_q = get_usm_allocations([x, *condlist, funclist]) + else: + usm_type, exec_q = get_usm_allocations([x, *condlist]) + + condlen = len(condlist) + try: + if isinstance(funclist, str): + raise TypeError + funclen = len(funclist) + except TypeError as e: + raise TypeError("funclist must be a sequence of scalars") from e + if condlen == funclen: + # default value is zero + result = dpnp.zeros_like(x, usm_type=usm_type, sycl_queue=exec_q) + elif condlen + 1 == funclen: + # default value is the last element of funclist + func = funclist[-1] + funclist = funclist[:-1] + if callable(func): + raise NotImplementedError( + "Callable functions are not supported currently" + ) + result = dpnp.full( + x.shape, func, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q + ) + else: + raise ValueError( + f"with {condlen} condition(s), either {condlen} or {condlen + 1} " + "functions are expected" + ) + + for condition, func in zip(condlist, funclist): + if callable(func): + raise NotImplementedError( + "Callable functions are not supported currently" + ) + if isinstance(func, dpnp.ndarray): + func = func.astype(x.dtype) + else: + func = x.dtype.type(func) + + # TODO: possibly can use func.item() to make sure that func is always + # a scalar and simplify the backend but current implementation of + # ndarray.item() copies to host memory and it is not efficient for + # large arrays + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + ht_ev, fun_ev = fi._piecewise( + exec_q, + func, # it is a scalar or 0d array + dpnp.get_usm_ndarray(condition), + dpnp.get_usm_ndarray(result), + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, fun_ev) + + return result diff --git a/dpnp/tests/test_functional.py b/dpnp/tests/test_functional.py index 14c7e086cd5..295cee4f0c8 100644 --- a/dpnp/tests/test_functional.py +++ b/dpnp/tests/test_functional.py @@ -1,10 +1,20 @@ import numpy import pytest -from numpy.testing import assert_array_equal, assert_raises +from numpy.testing import ( + assert_array_equal, + assert_equal, + assert_raises, + assert_raises_regex, +) import dpnp -from .helper import get_all_dtypes +from .helper import ( + assert_dtype_allclose, + generate_random_numpy_array, + get_all_dtypes, + get_unsigned_dtypes, +) class TestApplyAlongAxis: @@ -65,3 +75,244 @@ def custom_func(x, axis): ia = dpnp.arange(24).reshape(2, 3, 4) assert_raises(ValueError, dpnp.apply_over_axes, custom_func, ia, 1) + + +class TestPiecewise: + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_unsigned=True) + ) + @pytest.mark.parametrize("funclist", [[True, False], [-1, 1], [-1.5, 1.5]]) + def test_basic(self, dtype, funclist): + low = 0 if dpnp.issubdtype(dtype, dpnp.unsignedinteger) else -10 + a = generate_random_numpy_array(10, dtype=dtype, low=low) + ia = dpnp.array(a) + + expected = numpy.piecewise(a, [a < 0, a >= 0], funclist) + result = dpnp.piecewise(ia, [ia < 0, ia >= 0], funclist) + assert a.dtype == result.dtype + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_unsigned_dtypes()) + @pytest.mark.parametrize("funclist", [[True, False], [1, 2], [1.5, 4.5]]) + def test_unsigned(self, dtype, funclist): + a = generate_random_numpy_array(10, dtype=dtype, low=0) + ia = dpnp.array(a) + + expected = numpy.piecewise(a, [a < 0, a >= 0], funclist) + result = dpnp.piecewise(ia, [ia < 0, ia >= 0], funclist) + assert a.dtype == result.dtype + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + def test_basic_complex(self, dtype): + a = generate_random_numpy_array(10, dtype=dtype) + ia = dpnp.array(a) + funclist = [-1.5 - 1j * 1.5, 1.5 + 1j * 1.5] + + if numpy.issubdtype(dtype, numpy.complexfloating) or dtype == dpnp.bool: + expected = numpy.piecewise(a, [a < 0, a >= 0], funclist) + result = dpnp.piecewise(ia, [ia < 0, ia >= 0], funclist) + assert a.dtype == result.dtype + assert_dtype_allclose(result, expected) + else: + # If dtype is not complex, piecewise should raise an error + pytest.raises( + TypeError, numpy.piecewise, a, [a < 0, a >= 0], funclist + ) + pytest.raises( + TypeError, dpnp.piecewise, ia, [ia < 0, ia >= 0], funclist + ) + + def test_simple(self): + a = numpy.array([0, 0]) + ia = dpnp.array(a) + # Condition is single bool list + expected = numpy.piecewise(a, [True, False], [1]) + result = dpnp.piecewise(ia, [True, False], [1]) + assert_array_equal(result, expected) + + # List of conditions: single bool list + expected = numpy.piecewise(a, [[True, False]], [1]) + result = dpnp.piecewise(ia, [[True, False]], [1]) + assert_array_equal(result, expected) + + # Conditions is single bool array + expected = numpy.piecewise(a, [numpy.array([True, False])], [1]) + result = dpnp.piecewise(ia, dpnp.array([True, False]), [1]) + assert_array_equal(result, expected) + + # Condition is single int array + expected = numpy.piecewise(a, [numpy.array([1, 0])], [1]) + result = dpnp.piecewise(ia, dpnp.array([1, 0]), [1]) + assert_array_equal(result, expected) + + # List of conditions: int array + expected = numpy.piecewise(a, [numpy.array([1, 0])], [1]) + result = dpnp.piecewise(ia, [dpnp.array([1, 0])], [1]) + assert_array_equal(result, expected) + + # List of conditions: single bool tuple + expected = numpy.piecewise(a, ([True, False], [False, True]), [1, -4]) + result = dpnp.piecewise(ia, ([True, False], [False, True]), [1, -4]) + assert_array_equal(result, expected) + + # Condition is single bool tuple + expected = numpy.piecewise(a, (True, False), [1]) + result = dpnp.piecewise(ia, (True, False), [1]) + assert_array_equal(result, expected) + + def test_error_dpnp(self): + ia = dpnp.array([0, 0]) + # values cannot be a callable function + assert_raises_regex( + NotImplementedError, + "Callable functions are not supported currently", + dpnp.piecewise, + ia, + [dpnp.array([True, False])], + [lambda x: -1], + ) + + # default value cannot be a callable function + assert_raises_regex( + NotImplementedError, + "Callable functions are not supported currently", + dpnp.piecewise, + ia, + [dpnp.array([True, False])], + [-1, lambda x: 1], + ) + + # funclist is not array-like + assert_raises_regex( + TypeError, + "funclist must be a sequence of scalars", + dpnp.piecewise, + ia, + [dpnp.array([True, False])], + 1, + ) + + assert_raises_regex( + TypeError, + "object of type", + numpy.piecewise, + ia.asnumpy(), + [numpy.array([True, False])], + 1, + ) + + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_error(self, xp): + ia = xp.array([0, 0]) + # not enough functions + assert_raises_regex( + ValueError, + "1 or 2 functions are expected", + xp.piecewise, + ia, + [xp.array([True, False])], + [], + ) + + # extra function + assert_raises_regex( + ValueError, + "1 or 2 functions are expected", + xp.piecewise, + ia, + [xp.array([True, False])], + [1, 2, 3], + ) + + def test_two_conditions(self): + a = numpy.array([1, 2]) + ia = dpnp.array(a) + cond = numpy.array([True, False]) + icond = dpnp.array(cond) + expected = numpy.piecewise(a, [cond, cond], [3, 4]) + result = dpnp.piecewise(ia, [icond, icond], [3, 4]) + assert_array_equal(result, expected) + + def test_default(self): + a = numpy.array([1, 2]) + ia = dpnp.array(a) + # No value specified for x[1], should be 0 + expected = numpy.piecewise(a, [True, False], [2]) + result = dpnp.piecewise(ia, [True, False], [2]) + assert_array_equal(result, expected) + + # Should set x[1] to 3 + expected = numpy.piecewise(a, [True, False], [2, 3]) + result = dpnp.piecewise(ia, [True, False], [2, 3]) + assert_array_equal(result, expected) + + def test_0d(self): + a = numpy.array(3) + ia = dpnp.array(a) + + expected = numpy.piecewise(a, a > 3, [4, 0]) + result = dpnp.piecewise(ia, ia > 3, [4, 0]) + assert_array_equal(result, expected) + + a = numpy.array(5) + ia = dpnp.array(a) + expected = numpy.piecewise(a, [True, False], [1, 0]) + result = dpnp.piecewise(ia, [True, False], [1, 0]) + assert_array_equal(result, expected) + + expected = numpy.piecewise(a, [False, False, True], [1, 2, 3]) + result = dpnp.piecewise(ia, [False, False, True], [1, 2, 3]) + assert_array_equal(result, expected) + + def test_0d_comparison(self): + a = numpy.array(3) + ia = dpnp.array(a) + expected = numpy.piecewise(a, [a > 3, a <= 3], [4, 0]) + result = dpnp.piecewise(ia, [ia > 3, ia <= 3], [4, 0]) + assert_array_equal(result, expected) + + a = numpy.array(4) + ia = dpnp.array(a) + expected = numpy.piecewise( + a, [a <= 3, (a > 3) * (a <= 5), a > 5], [1, 2, 3] + ) + result = dpnp.piecewise( + ia, [ia <= 3, (ia > 3) * (ia <= 5), ia > 5], [1, 2, 3] + ) + assert_array_equal(result, expected) + + assert_raises_regex( + ValueError, + "2 or 3 functions are expected", + dpnp.piecewise, + ia, + [ia <= 3, ia > 3], + [1], + ) + assert_raises_regex( + ValueError, + "2 or 3 functions are expected", + dpnp.piecewise, + ia, + [ia <= 3, ia > 3], + [1, 1, 1, 1], + ) + + def test_0d_0d_condition(self): + a = numpy.array(3) + ia = dpnp.array(a) + c = numpy.array(a > 3) + ic = dpnp.array(ia > 3) + + expected = numpy.piecewise(a, [c], [1, 2]) + result = dpnp.piecewise(ia, [ic], [1, 2]) + assert_equal(result, expected) + + def test_multidimensional_extrafunc(self): + a = numpy.array([[-2.5, -1.5, -0.5], [0.5, 1.5, 2.5]]) + ia = dpnp.array(a) + + expected = numpy.piecewise(a, [a < 0, a >= 2], [-1, 1, 3]) + result = dpnp.piecewise(ia, [ia < 0, ia >= 2], [-1, 1, 3]) + assert_array_equal(result, expected) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 31dfb74f2cf..493aba7b225 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1179,6 +1179,19 @@ def test_apply_over_axes(device): assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) +@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +def test_piecewise(device): + x = dpnp.array([0, 0], device=device) + y = dpnp.array([True, False], device=device) + z = dpnp.array([1, -1], device=device) + result = dpnp.piecewise(x, y, z) + res_sycl_queue = result.sycl_queue + + assert_sycl_queue_equal(res_sycl_queue, x.sycl_queue) + assert_sycl_queue_equal(res_sycl_queue, y.sycl_queue) + assert_sycl_queue_equal(res_sycl_queue, z.sycl_queue) + + @pytest.mark.parametrize("device_x", valid_dev, ids=dev_ids) @pytest.mark.parametrize("device_y", valid_dev, ids=dev_ids) def test_asarray(device_x, device_y): diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index aed316eca53..bf494bb3e0f 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -755,6 +755,19 @@ def test_apply_over_axes(usm_type): assert x.usm_type == y.usm_type +@pytest.mark.parametrize("usm_type", list_of_usm_types) +def test_piecewise(usm_type): + x = dpnp.array([0, 0], usm_type=usm_type) + y = dpnp.array([True, False], usm_type=usm_type) + z = dpnp.array([1, -1], usm_type=usm_type) + result = dpnp.piecewise(x, y, z) + res_usm_type = result.usm_type + + assert x.usm_type == res_usm_type + assert y.usm_type == res_usm_type + assert z.usm_type == res_usm_type + + @pytest.mark.parametrize( "func,data1,data2", [ diff --git a/dpnp/tests/third_party/cupy/functional_tests/init.py b/dpnp/tests/third_party/cupy/functional_tests/init.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dpnp/tests/third_party/cupy/functional_tests/test_piecewise.py b/dpnp/tests/third_party/cupy/functional_tests/test_piecewise.py new file mode 100644 index 00000000000..d378fb909bb --- /dev/null +++ b/dpnp/tests/third_party/cupy/functional_tests/test_piecewise.py @@ -0,0 +1,121 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from dpnp.tests.third_party.cupy import testing + + +class TestPiecewise(unittest.TestCase): + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise(self, xp, dtype): + x = xp.linspace(2.5, 12.5, 6, dtype=dtype) + condlist = [x < 0, x >= 0, x < 5, x >= 1.5] + funclist = xp.array([-1, 1, 2, 5]) + return xp.piecewise(x, condlist, funclist) + + @pytest.mark.skip("scalar input is not supported") + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_scalar_input(self, xp, dtype): + x = dtype(2) + condlist = [x < 0, x >= 0] + funclist = [1, 10] + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_scalar_condition(self, xp, dtype): + x = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype) + condlist = True + funclist = xp.array([-10, 10]) + return xp.piecewise(x, condlist, funclist) + + @testing.for_signed_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_otherwise_condition1(self, xp, dtype): + x = xp.linspace(-2, 20, 12, dtype=dtype) + condlist = [x > 15, x <= 5, x == 0, x == 10] + funclist = xp.array([-1, 0, 2, 3, -5]) + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_otherwise_condition2(self, xp, dtype): + x = xp.array([-10, 20, 30, 40]).astype(dtype) + condlist = [ + xp.array([True, False, False, True]), + xp.array([True, False, False, True]), + ] + funclist = xp.array([-1, 1, 2]) + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_zero_dim_input(self, xp, dtype): + x = testing.shaped_random(shape=(), xp=xp, dtype=dtype) + condlist = [x < 0, x > 0] + funclist = [10, 1, 2] + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_ndim_input(self, xp, dtype): + x = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype) + condlist = [x < 0, x > 0] + funclist = [10, 1, 2] + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_zero_dim_condlist(self, xp, dtype): + x = testing.shaped_random(shape=(), xp=xp, dtype=dtype) + condlist = [testing.shaped_random(shape=(), xp=xp, dtype=bool)] + funclist = [1, 2] + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_piecewise_ndarray_condlist_funclist(self, xp, dtype): + x = xp.linspace(1, 20, 12, dtype=dtype) + condlist = xp.array([x > 15, x <= 5, x == 0, x == 10]) + funclist = xp.array([-1, 0, 2, 3, -5]).astype(dtype) + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_complex=True + ) + @testing.numpy_cupy_array_equal() + def test_piecewise_diff_types_funclist(self, xp, dtype1, dtype2): + x = xp.linspace(1, 20, 12, dtype=dtype1) + condlist = [x > 15, x <= 5, x == 0, x == 10] + funclist = xp.array([1, 0, 2, 3, 5], dtype=dtype2) + return xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + def test_mismatched_lengths(self, dtype): + funclist = [-1, 0, 2, 4, 5] + for xp in (numpy, cupy): + x = xp.linspace(-2, 4, 6, dtype=dtype) + condlist = [x < 0, x >= 0] + with pytest.raises(ValueError): + xp.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + def test_callable_funclist(self, dtype): + x = cupy.linspace(-2, 4, 6, dtype=dtype) + condlist = [x < 0, x > 0] + funclist = [lambda x: -x, lambda x: x] + with pytest.raises(NotImplementedError): + cupy.piecewise(x, condlist, funclist) + + @testing.for_all_dtypes() + def test_mixed_funclist(self, dtype): + x = cupy.linspace(-2, 2, 6, dtype=dtype) + condlist = [x < 0, x == 0, x > 0] + funclist = [-10, lambda x: -x, 10, lambda x: x] + with pytest.raises(NotImplementedError): + cupy.piecewise(x, condlist, funclist) diff --git a/dpnp/tests/third_party/cupy/functional_tests/test_vectorize.py b/dpnp/tests/third_party/cupy/functional_tests/test_vectorize.py new file mode 100644 index 00000000000..910ab2dc0aa --- /dev/null +++ b/dpnp/tests/third_party/cupy/functional_tests/test_vectorize.py @@ -0,0 +1,677 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from dpnp.tests.third_party.cupy import testing + +# from cupy.cuda import runtime + +pytest.skip("dpnp.vectorize is not implemented", allow_module_level=True) + + +class TestVectorizeOps(unittest.TestCase): + + def _run(self, func, xp, dtypes): + f = xp.vectorize(func) + args = [ + testing.shaped_random((20, 30), xp, dtype, seed=seed) + for seed, dtype in enumerate(dtypes) + ] + return f(*args) + + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose(rtol={"default": 1e-6, numpy.float16: 1.5e-3}) + def test_vectorize_reciprocal(self, xp, dtype): + def my_reciprocal(x): + scalar = xp.dtype(dtype).type(10) + return xp.reciprocal(x + scalar) + + return self._run(my_reciprocal, xp, [dtype]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal() + def test_vectorize_add(self, xp, dtype1, dtype2): + def my_add(x, y): + return x + y + + return self._run(my_add, xp, [dtype1, dtype2]) + + @testing.for_dtypes("bhilqefdFD") + @testing.numpy_cupy_array_equal() + def test_vectorize_sub(self, xp, dtype): + def my_sub(x, y): + return x - y + + return self._run(my_sub, xp, [dtype, dtype]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_allclose(rtol=1e-6) + def test_vectorize_mul(self, xp, dtype1, dtype2): + def my_mul(x, y): + return x * y + + return self._run(my_mul, xp, [dtype1, dtype2]) + + @testing.for_dtypes("qQefdFD") + @testing.numpy_cupy_allclose(rtol=1e-5) + def test_vectorize_pow(self, xp, dtype): + def my_pow(x, y): + return x**y + + f = xp.vectorize(my_pow) + x1 = testing.shaped_random((20, 30), xp, dtype, seed=0) + x2 = testing.shaped_random((20, 30), xp, dtype, seed=1) + x1[x1 == 0] = 1 + return f(x1, x2) + + @testing.for_all_dtypes_combination( + names=("dtype1", "dtype2"), no_bool=True, no_complex=True + ) + @testing.numpy_cupy_allclose(rtol=1e-5) + def test_vectorize_minmax(self, xp, dtype1, dtype2): + def my_minmax(x, y): + return max(x, y) - min(x, y) + + f = xp.vectorize(my_minmax) + x1 = testing.shaped_random((20, 30), xp, dtype1, seed=0) + x2 = testing.shaped_random((20, 30), xp, dtype2, seed=1) + x1[x1 == 0] = 1 + return f(x1, x2) + + def run_div(self, func, xp, dtypes): + dtype1, dtype2 = dtypes + f = xp.vectorize(func) + x1 = testing.shaped_random((20, 30), xp, dtype1, seed=0) + x2 = testing.shaped_random((20, 30), xp, dtype2, seed=1) + x2[x2 == 0] = 1 + return f(x1, x2) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_allclose(rtol=1e-6) + @testing.with_requires("numpy>=1.23", "numpy!=1.24.0", "numpy!=1.24.1") + def test_vectorize_div(self, xp, dtype1, dtype2): + def my_div(x, y): + return x / y + + return self.run_div(my_div, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_allclose(accept_error=TypeError) + def test_vectorize_floor_div(self, xp, dtype1, dtype2): + def my_floor_div(x, y): + return x // y + + return self.run_div(my_floor_div, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_allclose(rtol=1e-6, atol=1e-6, accept_error=TypeError) + def test_vectorize_mod(self, xp, dtype1, dtype2): + def my_mod(x, y): + return x % y + + return self.run_div(my_mod, xp, [dtype1, dtype2]) + + @testing.for_dtypes("iIlLqQ") + @testing.numpy_cupy_array_equal() + def test_vectorize_lshift(self, xp, dtype): + def my_lshift(x, y): + return x << y + + return self._run(my_lshift, xp, [dtype, dtype]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_rshift(self, xp, dtype1, dtype2): + def my_lshift(x, y): + return x >> y + + return self._run(my_lshift, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_bit_or(self, xp, dtype1, dtype2): + def my_bit_or(x, y): + return x | y + + return self._run(my_bit_or, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_bit_and(self, xp, dtype1, dtype2): + def my_bit_and(x, y): + return x & y + + return self._run(my_bit_and, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_bit_xor(self, xp, dtype1, dtype2): + def my_bit_xor(x, y): + return x ^ y + + return self._run(my_bit_xor, xp, [dtype1, dtype2]) + + @testing.numpy_cupy_array_equal() + def test_vectorize_bit_invert(self, xp): + def my_bit_invert(x): + return ~x + + return self._run(my_bit_invert, xp, [numpy.int64]) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_logical_not(self, xp, dtype): + def my_logical_not(x): + return not x + + return self._run(my_logical_not, xp, [dtype]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_eq(self, xp, dtype1, dtype2): + def my_eq(x, y): + return x == y + + return self._run(my_eq, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_neq(self, xp, dtype1, dtype2): + def my_neq(x, y): + return x != y + + return self._run(my_neq, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_lt(self, xp, dtype1, dtype2): + def my_lt(x, y): + return x < y + + return self._run(my_lt, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_le(self, xp, dtype1, dtype2): + def my_le(x, y): + return x <= y + + return self._run(my_le, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_gt(self, xp, dtype1, dtype2): + def my_gt(x, y): + return x > y + + return self._run(my_gt, xp, [dtype1, dtype2]) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_ge(self, xp, dtype1, dtype2): + def my_ge(x, y): + return x >= y + + return self._run(my_ge, xp, [dtype1, dtype2]) + + @testing.for_dtypes("bhilqefdFD") + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_usub(self, xp, dtype): + def my_usub(x): + return -x + + return self._run(my_usub, xp, [dtype]) + + +class TestVectorizeExprs(unittest.TestCase): + + @testing.for_all_dtypes(name="cond_dtype", no_complex=True) + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_vectorize_ifexp(self, xp, dtype, cond_dtype): + def my_ifexp(c, x, y): + return x if c else y + + f = xp.vectorize(my_ifexp) + cond = testing.shaped_random((20, 30), xp, cond_dtype, seed=0) + x = testing.shaped_random((20, 30), xp, dtype, seed=1) + y = testing.shaped_random((20, 30), xp, dtype, seed=2) + return f(cond, x, y) + + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_vectorize_incr(self, xp, dtype): + def my_incr(x): + return x + 1 + + if dtype != xp.float64: + pytest.xfail("vectorize with scalars: no NEP 50") + + f = xp.vectorize(my_incr) + x = testing.shaped_random((20, 30), xp, dtype, seed=0) + return f(x) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal(accept_error=TypeError) + def test_vectorize_ufunc_call(self, xp, dtype): + def my_ufunc_add(x, y): + return xp.add(x, y) + + f = xp.vectorize(my_ufunc_add) + x = testing.shaped_random((20, 30), xp, dtype, seed=1) + y = testing.shaped_random((20, 30), xp, dtype, seed=2) + return f(x, y) + + @testing.with_requires("numpy>=1.25") + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_allclose( + rtol={numpy.float16: 1e3, "default": 1e-7}, accept_error=TypeError + ) + def test_vectorize_ufunc_call_dtype(self, xp, dtype1, dtype2): + def my_ufunc_add(x, y): + return xp.add(x, y, dtype=dtype2) + + f = xp.vectorize(my_ufunc_add) + x = testing.shaped_random((20, 30), xp, dtype1, seed=1) + y = testing.shaped_random((20, 30), xp, dtype1, seed=2) + return f(x, y) + + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2"), full=True) + @testing.numpy_cupy_array_equal( + accept_error=(TypeError, cupy.exceptions.ComplexWarning) + ) + def test_vectorize_typecast(self, xp, dtype1, dtype2): + typecast = xp.dtype(dtype2).type + + def my_typecast(x): + return typecast(x) + + f = xp.vectorize(my_typecast) + x = testing.shaped_random((20, 30), xp, dtype1, seed=1) + return f(x) + + +class TestVectorizeInstructions(unittest.TestCase): + + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_vectorize_assign_new(self, xp, dtype): + def my_assign(x): + y = x + x + return x + y + + f = xp.vectorize(my_assign) + x = testing.shaped_random((20, 30), xp, dtype, seed=1) + return f(x) + + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_vectorize_assign_update(self, xp, dtype): + def my_assign(x): + x = x + x + return x + x + + f = xp.vectorize(my_assign) + x = testing.shaped_random((20, 30), xp, dtype, seed=1) + return f(x) + + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_vectorize_augassign(self, xp, dtype): + def my_augassign(x): + x += x + return x + x + + f = xp.vectorize(my_augassign) + x = testing.shaped_random((20, 30), xp, dtype, seed=1) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_vectorize_const_assign(self, xp): + def my_typecast(x): + typecast = xp.dtype("f").type + return typecast(x) + + f = xp.vectorize(my_typecast) + x = testing.shaped_random((20, 30), xp, numpy.int32, seed=1) + return f(x) + + def test_vectorize_const_typeerror(self): + def my_invalid_type(x): + x = numpy.dtype("f").type + return x + + f = cupy.vectorize(my_invalid_type) + x = testing.shaped_random((20, 30), cupy, numpy.int32, seed=1) + with pytest.raises(TypeError): + f(x) + + def test_vectorize_const_non_toplevel(self): + def my_invalid_type(x): + if x == 3: + typecast = numpy.dtype("f").type + return x + + f = cupy.vectorize(my_invalid_type) + x = cupy.array([1, 2, 3, 4, 5]) + with pytest.raises(TypeError): + f(x) + + @testing.numpy_cupy_array_equal() + def test_vectorize_nonconst_for_value(self, xp): + def my_nonconst_result(x): + result = numpy.int32(0) + result = x + return result + + f = xp.vectorize(my_nonconst_result) + x = testing.shaped_random((20, 30), xp, numpy.int32, seed=1) + return f(x) + + +class TestVectorizeStmts(unittest.TestCase): + + @testing.numpy_cupy_array_equal() + def test_if(self, xp): + def func_if(x): + if x % 2 == 0: + y = x + else: + y = -x + return y + + f = xp.vectorize(func_if) + x = xp.array([1, 2, 3, 4, 5]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_if_no_orlese(self, xp): + def func_if(x): + y = 0 + if x % 2 == 0: + y = x + return y + + f = xp.vectorize(func_if) + x = xp.array([1, 2, 3, 4, 5]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_elif(self, xp): + def func_if(x): + y = 0 + if x % 2 == 0: + y = x + elif x % 3 == 0: + y = -x + return y + + f = xp.vectorize(func_if) + x = xp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_while(self, xp): + def func_while(x): + y = 0 + while x > 0: + y += x + x -= 1 + return y + + f = xp.vectorize(func_while) + x = xp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + return f(x) + + @testing.for_dtypes("qQ") + @testing.numpy_cupy_array_equal() + def test_for(self, xp, dtype): + def func_for(x): + y = 0 + for i in range(x): + y += i + return y + + f = xp.vectorize(func_for) + x = xp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_for_const_range(self, xp): + def func_for(x): + for i in range(3, 10): + x += i + return x + + f = xp.vectorize(func_for) + x = xp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_for_range_step(self, xp): + def func_for(x, y, z): + res = 0 + for i in range(x, y, z): + res += i * i + return res + + f = xp.vectorize(func_for) + start = xp.array([0, 1, 2, 3, 4, 5]) + stop = xp.array([-21, -23, -19, 17, 27, 24]) + step = xp.array([-3, -2, -1, 1, 2, 3]) + return f(start, stop, step) + + @testing.numpy_cupy_array_equal() + def test_for_update_counter(self, xp): + def func_for(x): + for i in range(10): + x += i + i += 1 + return x + + f = xp.vectorize(func_for) + x = xp.array([0, 1, 2, 3, 4]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_for_counter_after_loop(self, xp): + def func_for(x): + for i in range(10): + pass + return x + i + + f = xp.vectorize(func_for) + x = xp.array([0, 1, 2, 3, 4]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_for_compound_expression_param(self, xp): + def func_for(x, y): + res = 0 + for i in range(x * y): + res += i + return res + + f = xp.vectorize(func_for) + x = xp.array([0, 1, 2, 3, 4]) + return f(x, x) + + @testing.numpy_cupy_array_equal() + def test_for_update_loop_condition(self, xp): + def func_for(x): + res = 0 + for i in range(x): + res += i + x -= 1 + return res + + f = xp.vectorize(func_for) + x = xp.array([0, 1, 2, 3, 4]) + return f(x) + + @testing.numpy_cupy_array_equal() + def test_tuple(self, xp): + def func_tuple(x, y): + x, y = y, x + z = x, y + a, b = z + return a * a + b + + f = xp.vectorize(func_tuple) + x = xp.array([0, 1, 2, 3, 4]) + y = xp.array([5, 6, 7, 8, 9]) + return f(x, y) + + @testing.numpy_cupy_array_equal() + def test_tuple_pattern_match(self, xp): + def func_pattern_match(x, y): + x, y = y, x + z = x, y + (a, b), y = z, x + return a * a + b + y + + f = xp.vectorize(func_pattern_match) + x = xp.array([0, 1, 2, 3, 4]) + y = xp.array([5, 6, 7, 8, 9]) + return f(x, y) + + def test_tuple_pattern_match_type_error(self): + def func_pattern_match(x, y): + x, y = y, x + z = x, y + (a, b), z = z, x + return a * a + b + + f = cupy.vectorize(func_pattern_match) + x = cupy.array([0, 1, 2, 3, 4]) + y = cupy.array([5, 6, 7, 8, 9]) + with pytest.raises(TypeError, match="Data type mismatch of variable:"): + return f(x, y) + + @testing.numpy_cupy_array_equal() + def test_return_tuple(self, xp): + def func_tuple(x, y): + return x + y, x / y + + f = xp.vectorize(func_tuple) + x = xp.array([0, 1, 2, 3, 4]) + y = xp.array([5, 6, 7, 8, 9]) + return f(x, y) + + +class _MyClass: + + def __init__(self, x): + self.x = x + + +class TestVectorizeConstants(unittest.TestCase): + + @testing.numpy_cupy_array_equal() + def test_vectorize_const_value(self, xp): + + def my_func(x1, x2): + return x1 - x2 + const + + const = 8 + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, xp.int64, seed=1) + x2 = testing.shaped_random((20, 30), xp, xp.int64, seed=2) + return f(x1, x2) + + @testing.numpy_cupy_array_equal() + def test_vectorize_const_attr(self, xp): + + def my_func(x1, x2): + return x1 - x2 + const.x + + const = _MyClass(10) + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, xp.int64, seed=1) + x2 = testing.shaped_random((20, 30), xp, xp.int64, seed=2) + return f(x1, x2) + + +class TestVectorizeBroadcast(unittest.TestCase): + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_allclose(rtol=1e-5) + def test_vectorize_broadcast(self, xp, dtype): + def my_func(x1, x2): + return x1 + x2 + + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, dtype, seed=1) + x2 = testing.shaped_random((30,), xp, dtype, seed=2) + return f(x1, x2) + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_allclose(rtol=1e-5) + def test_vectorize_python_scalar_input(self, xp, dtype): + def my_func(x1, x2): + return x1 + x2 + + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, dtype, seed=1) + x2 = 1 + return f(x1, x2) + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_allclose(rtol=1e-5) + def test_vectorize_numpy_scalar_input(self, xp, dtype): + def my_func(x1, x2): + return x1 + x2 + + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, dtype, seed=1) + x2 = dtype(1) + return f(x1, x2) + + +class TestVectorize(unittest.TestCase): + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_allclose( + rtol={"default": 1e-5, numpy.float16: 1e-3 if runtime.is_hip else 1e-5} + ) + def test_vectorize_arithmetic_ops(self, xp, dtype): + def my_func(x1, x2, x3): + y = x1 + x2 * x3**x1 + x2 = y + x3 * x1 + return x1 + x2 + x3 + + f = xp.vectorize(my_func) + x1 = testing.shaped_random((20, 30), xp, dtype, seed=1, scale=4) + x2 = testing.shaped_random((20, 30), xp, dtype, seed=2, scale=4) + x3 = testing.shaped_random((20, 30), xp, dtype, seed=3, scale=4) + return f(x1, x2, x3) + + @testing.numpy_cupy_array_equal() + def test_vectorize_lambda(self, xp): + f = xp.vectorize(lambda a, b, c: a + b * c) + x1 = testing.shaped_random((20, 30), xp, numpy.int64, seed=1) + x2 = testing.shaped_random((20, 30), xp, numpy.int64, seed=2) + x3 = testing.shaped_random((20, 30), xp, numpy.int64, seed=3) + return f(x1, x2, x3) + + def test_vectorize_lambda_xfail(self): + functions = [lambda a, b: a + b, lambda a, b: a * b] + f = cupy.vectorize(functions[0]) + x1 = testing.shaped_random((20, 30), cupy, numpy.int64, seed=1) + x2 = testing.shaped_random((20, 30), cupy, numpy.int64, seed=2) + with pytest.raises(ValueError, match="Multiple callables are found"): + return f(x1, x2) + + @testing.numpy_cupy_array_equal() + def test_relu(self, xp): + f = xp.vectorize(lambda x: x if x > 0.0 else 0.0) + a = xp.array([0.4, -0.2, 1.8, -1.2], dtype=xp.float32) + return f(a) # float32 + + def test_relu_type_error(self): + f = cupy.vectorize(lambda x: x if x > 0.0 else cupy.float64(0.0)) + a = cupy.array([0.4, -0.2, 1.8, -1.2], dtype=cupy.float32) + with pytest.raises(TypeError): + return f(a)