Skip to content

Commit 5408438

Browse files
committed
Adding rank member to allow sfinae based on rank
1 parent de229b4 commit 5408438

File tree

9 files changed

+190
-4
lines changed

9 files changed

+190
-4
lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ for details.
6969
missing
7070
histogram
7171
random
72+
sfinae
7273
file_loading
7374
build-options
7475
pitfall

docs/source/sfinae.rst

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
2+
3+
Distributed under the terms of the BSD 3-Clause License.
4+
5+
The full license is in the file LICENSE, distributed with this software.
6+
7+
.. _histogram:
8+
9+
SFINAE
10+
======
11+
12+
Rank overload
13+
-------------
14+
15+
All `xtensor`'s classes have a member ``rank`` that can be used
16+
to overload based on rank using *SFINAE*.
17+
Consider the following example:
18+
19+
.. code-block:: cpp
20+
21+
template <class E, std::enable_if_t<!xt::has_rank_t<E, 2>::value, int> = 0>
22+
inline E foo(E&& a)
23+
{
24+
... // act on object of flexible rank, or fixed rank != 2
25+
}
26+
27+
template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
28+
inline E foo(E&& a)
29+
{
30+
... // act on object of fixed rank == 2
31+
}
32+
33+
TEST(sfinae, rank_basic)
34+
{
35+
xt::xarray<size_t> a = {{9, 9}, {9, 9}};
36+
xt::xtensor<size_t, 1> b = {9, 9};
37+
xt::xtensor<size_t, 2> c = {{9, 9}, {9, 9}};
38+
39+
foo(a); // flexible rank -> first overload
40+
foo(b); // fixed rank == 2 -> first overload
41+
foo(c); // fixed rank == 2 -> second overload
42+
}
43+
44+
.. note::
45+
46+
If one wants to test for more than a single value for ``rank``,
47+
one can use the default value ``SIZE_MAX`` used for flexible rank objects.
48+
For example, one could have the following overloads:
49+
50+
.. code-block:: cpp
51+
52+
// flexible rank
53+
template <class E, std::enable_if_t<xt::has_rank_t<E, SIZE_MAX>::value, int> = 0>
54+
inline E foo(E&& a);
55+
56+
// fixed rank == 1
57+
template <class E, std::enable_if_t<xt::has_rank_t<E, 1>::value, int> = 0>
58+
inline E foo(E&& a);
59+
60+
// fixed rank == 2
61+
template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
62+
inline E foo(E&& a);
63+
64+
Note that fixed ranks other than 1 and 2 will raise a compiler error.
65+
66+
Of course, if one wants a more limited scope, one could also do the following:
67+
68+
.. code-block:: cpp
69+
70+
// flexible rank
71+
inline void foo(xt::xarray<double>& a);
72+
73+
// fixed rank == 1
74+
inline void foo(xt::xtensor<double,1>& a);
75+
76+
// fixed rank == 2
77+
inline void foo(xt::xtensor<double,2>& a);

include/xtensor/xarray.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ namespace xt
104104
using inner_backstrides_type = typename base_type::inner_backstrides_type;
105105
using temporary_type = typename semantic_base::temporary_type;
106106
using expression_tag = Tag;
107+
constexpr static std::size_t rank = SIZE_MAX;
107108

108109
xarray_container();
109110
explicit xarray_container(const shape_type& shape, layout_type l = L);

include/xtensor/xfixed.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace xt
5454
namespace detail
5555
{
5656
/**************************************************************************************
57-
The following is something we can currently only dream about -- for when we drop
57+
The following is something we can currently only dream about -- for when we drop
5858
support for a lot of the old compilers (e.g. GCC 4.9, MSVC 2017 ;)
5959
6060
template <class T>
@@ -281,7 +281,7 @@ namespace xt
281281
* with tensor semantic and fixed dimension
282282
*
283283
* @tparam ET The type of the elements.
284-
* @tparam S The xshape template paramter of the container.
284+
* @tparam S The xshape template paramter of the container.
285285
* @tparam L The layout_type of the tensor.
286286
* @tparam SH Wether the tensor can be used as a shared expression.
287287
* @tparam Tag The expression tag.
@@ -313,6 +313,7 @@ namespace xt
313313
using expression_tag = Tag;
314314

315315
constexpr static std::size_t N = std::tuple_size<shape_type>::value;
316+
constexpr static std::size_t rank = N;
316317

317318
xfixed_container() = default;
318319
xfixed_container(const value_type& v);
@@ -616,7 +617,7 @@ namespace xt
616617

617618
/**
618619
* Allocates an xfixed_container with shape S with values from a C array.
619-
* The type returned by get_init_type_t is raw C array ``value_type[X][Y][Z]`` for ``xt::xshape<X, Y, Z>``.
620+
* The type returned by get_init_type_t is raw C array ``value_type[X][Y][Z]`` for ``xt::xshape<X, Y, Z>``.
620621
* C arrays can be initialized with the initializer list syntax, but the size is checked at compile
621622
* time to prevent errors.
622623
* Note: for clang < 3.8 this is an initializer_list and the size is not checked at compile-or runtime.

include/xtensor/xtensor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ namespace xt
105105
using inner_strides_type = typename base_type::inner_strides_type;
106106
using temporary_type = typename semantic_base::temporary_type;
107107
using expression_tag = Tag;
108+
constexpr static std::size_t rank = N;
108109

109110
xtensor_container();
110111
xtensor_container(nested_initializer_list_t<value_type, N> t);
@@ -759,7 +760,7 @@ namespace xt
759760
std::fill(m_storage.begin(), m_storage.end(), e);
760761
return *this;
761762
}
762-
763+
763764
template <class EC, std::size_t N, layout_type L, class Tag>
764765
inline auto xtensor_view<EC, N, L, Tag>::storage_impl() noexcept -> storage_type&
765766
{

include/xtensor/xutils.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,19 @@ namespace xt
857857

858858
template <class ST>
859859
using inner_reference_t = typename inner_reference<ST>::type;
860+
861+
/************
862+
* has_rank *
863+
************/
864+
865+
template <class E, size_t N>
866+
struct has_rank
867+
{
868+
using type = std::integral_constant<bool, std::decay_t<E>::rank == N>;
869+
};
870+
871+
template <class E, size_t N>
872+
using has_rank_t = typename has_rank<std::decay_t<E>, N>::type;
860873
}
861874

862875
#endif

include/xtensor/xview.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ namespace xt
432432

433433
using container_iterator = pointer;
434434
using const_container_iterator = const_pointer;
435+
constexpr static std::size_t rank = SIZE_MAX;
435436

436437
// The FSL argument prevents the compiler from calling this constructor
437438
// instead of the copy constructor when sizeof...(SL) == 0.

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ set(XTENSOR_TESTS
224224
test_extended_xhistogram.cpp
225225
test_extended_xsort.cpp
226226
test_xchunked_array.cpp
227+
test_sfinae.cpp
227228
)
228229

229230
if(nlohmann_json_FOUND)

test/test_sfinae.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#include <complex>
11+
#include <limits>
12+
13+
#include "gtest/gtest.h"
14+
#include "xtensor/xtensor.hpp"
15+
#include "xtensor/xarray.hpp"
16+
// #include "xtensor/xfixed.hpp"
17+
#include "xtensor/xview.hpp"
18+
19+
namespace xt
20+
{
21+
template <class E, std::enable_if_t<!xt::has_rank_t<E, 2>::value, int> = 0>
22+
inline E sfinae_rank_basic_func(E&& a)
23+
{
24+
E b = a;
25+
b.fill(0);
26+
return b;
27+
}
28+
29+
template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
30+
inline E sfinae_rank_basic_func(E&& a)
31+
{
32+
E b = a;
33+
b.fill(2);
34+
return b;
35+
}
36+
37+
TEST(sfinae, rank_basic)
38+
{
39+
xt::xarray<size_t> a = {{9, 9, 9}, {9, 9, 9}};
40+
xt::xtensor<size_t, 1> b = {9, 9};
41+
xt::xtensor<size_t, 2> c = {{9, 9}, {9, 9}};
42+
// xt::xtensor_fixed<size_t, xt::xshape<2, 2>> d = {{9, 9}, {9, 9}};
43+
auto v = xt::view(c, 0, xt::all());
44+
45+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_basic_func(a), 0ul)));
46+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_basic_func(b), 0ul)));
47+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_basic_func(c), 2ul)));
48+
// EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_basic_func(d), 2ul)));
49+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_basic_func(v), 0ul)));
50+
}
51+
52+
template <class E, std::enable_if_t<xt::has_rank_t<E, SIZE_MAX>::value, int> = 0>
53+
inline E sfinae_rank_func(E&& a)
54+
{
55+
E b = a;
56+
b.fill(0);
57+
return b;
58+
}
59+
60+
template <class E, std::enable_if_t<xt::has_rank_t<E, 1>::value, int> = 0>
61+
inline E sfinae_rank_func(E&& a)
62+
{
63+
E b = a;
64+
b.fill(1);
65+
return b;
66+
}
67+
68+
template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
69+
inline E sfinae_rank_func(E&& a)
70+
{
71+
E b = a;
72+
b.fill(2);
73+
return b;
74+
}
75+
76+
TEST(sfinae, rank)
77+
{
78+
xt::xarray<size_t> a = {{9, 9, 9}, {9, 9, 9}};
79+
xt::xtensor<size_t, 1> b = {9, 9};
80+
xt::xtensor<size_t, 2> c = {{9, 9}, {9, 9}};
81+
// xt::xtensor_fixed<size_t, xt::xshape<2, 2>> d = {{9, 9}, {9, 9}};
82+
auto v = xt::view(c, 0, xt::all());
83+
84+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_func(a), 0ul)));
85+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_func(b), 1ul)));
86+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_func(c), 2ul)));
87+
// EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_func(d), 2ul)));
88+
EXPECT_TRUE(xt::all(xt::equal(sfinae_rank_func(v), 0ul)));
89+
}
90+
}

0 commit comments

Comments
 (0)