55#include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
66#include <nbl/builtin/hlsl/rwmc/ResolveParameters.hlsl>
77#include <nbl/builtin/hlsl/concepts/accessors/loadable_image.hlsl>
8+ #include <nbl/builtin/hlsl/colorspace.hlsl>
9+ #include <nbl/builtin/hlsl/vector_utils/vector_traits.hlsl>
810
911namespace nbl
1012{
@@ -19,23 +21,21 @@ namespace rwmc
1921// not the greatest syntax but works
2022#define NBL_CONCEPT_PARAM_0 (a,T)
2123#define NBL_CONCEPT_PARAM_1 (scalar,VectorScalarType)
22- #define NBL_CONCEPT_PARAM_2 (vec,vector <VectorScalarType, Dims>)
2324// start concept
2425 NBL_CONCEPT_BEGIN (2 )
2526// need to be defined AFTER the concept begins
2627#define a NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
2728#define scalar NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
28- #define vec NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
2929NBL_CONCEPT_END (
30- ((NBL_CONCEPT_REQ_EXPR)((a.calcLuma (vec ))))
30+ ((NBL_CONCEPT_REQ_EXPR)((a.calcLuma (vector <VectorScalarType, 3 >(scalar, scalar, scalar) ))))
3131);
3232#undef a
33- #undef vec
33+ #undef scalar
3434#include <nbl/builtin/hlsl/concepts/__end.hlsl>
3535
3636/* ResolveAccessor is required to:
3737* - satisfy `LoadableImage` concept requirements
38- * - implement function called `calcLuma` which calculates luma from a pixel value
38+ * - implement a function called `calcLuma` which calculates luma from a 3-component pixel value
3939*/
4040
4141template<typename T, typename VectorScalarType, int32_t Dims>
@@ -50,9 +50,9 @@ struct ResolveAccessorAdaptor
5050
5151 RWTexture2DArray <float32_t4> cascade;
5252
53- float32_t calcLuma (in float32_t3 col)
53+ float32_t calcLuma (NBL_REF_ARG ( float32_t3) col)
5454 {
55- return hlsl::dot<float32_t3>(hlsl:: transpose (colorspace::scRGBtoXYZ )[1 ], col);
55+ return hlsl::dot<float32_t3>(colorspace::scRGB:: ToXYZ ( )[1 ], col);
5656 }
5757
5858 template<typename OutputScalarType, int32_t Dimension>
@@ -69,10 +69,11 @@ struct ResolveAccessorAdaptor
6969 }
7070};
7171
72- template<typename CascadeAccessor, typename OutputColorType> // NBL_PRIMARY_REQUIRES(ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
72+ template<typename CascadeAccessor, typename OutputColorTypeVec NBL_PRIMARY_REQUIRES (concepts::Vector<OutputColorTypeVec> && ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
7373struct Resolver
7474{
75- using output_type = OutputColorType;
75+ using output_type = OutputColorTypeVec;
76+ using scalar_t = typename vector_traits<output_type>::scalar_type;
7677
7778 struct CascadeSample
7879 {
@@ -91,13 +92,15 @@ struct Resolver
9192
9293 output_type operator ()(NBL_REF_ARG (CascadeAccessor) acc, const int16_t2 coord)
9394 {
94- float reciprocalBaseI = 1.f ;
95+ using scalar_t = typename vector_traits<output_type>::scalar_type;
96+
97+ scalar_t reciprocalBaseI = 1.f ;
9598 CascadeSample curr = __sampleCascade (acc, coord, 0u, reciprocalBaseI);
9699
97- float32_t3 accumulation = float32_t3 (0.0f , 0.0f , 0.0f );
98- float Emin = params.initialEmin;
100+ output_type accumulation = output_type (0.0f , 0.0f , 0.0f );
101+ scalar_t Emin = params.initialEmin;
99102
100- float prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
103+ scalar_t prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
101104 for (int16_t i = 0u; i <= params.lastCascadeIndex; i++)
102105 {
103106 const bool notFirstCascade = i != 0 ;
@@ -110,13 +113,13 @@ struct Resolver
110113 next = __sampleCascade (acc, coord, int16_t (i + 1 ), reciprocalBaseI);
111114 }
112115
113- float reliability = 1.f ;
116+ scalar_t reliability = 1.f ;
114117 // sample counting-based reliability estimation
115118 if (params.reciprocalKappa <= 1.f )
116119 {
117- float localReliability = curr.normalizedCenterLuma;
120+ scalar_t localReliability = curr.normalizedCenterLuma;
118121 // reliability in 3x3 pixel block (see robustness)
119- float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
122+ scalar_t globalReliability = curr.normalizedNeighbourhoodAverageLuma;
120123 if (notFirstCascade)
121124 {
122125 localReliability += prevNormalizedCenterLuma;
@@ -130,11 +133,11 @@ struct Resolver
130133 // check if above minimum sampling threshold (avg 9 sample occurrences in 3x3 neighbourhood), then use per-pixel reliability (NOTE: ternary op is in reverse)
131134 reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
132135 {
133- const float accumLuma = acc.calcLuma (accumulation);
136+ const scalar_t accumLuma = acc.calcLuma (accumulation);
134137 if (accumLuma > Emin)
135138 Emin = accumLuma;
136139
137- const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
140+ const scalar_t colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
138141
139142 reliability += colorReliability;
140143 reliability *= params.NOverKappa;
@@ -156,19 +159,18 @@ struct Resolver
156159
157160 // pseudo private stuff:
158161
159- CascadeSample __sampleCascade (NBL_REF_ARG (CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, float reciprocalBaseI)
162+ CascadeSample __sampleCascade (NBL_REF_ARG (CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, scalar_t reciprocalBaseI)
160163 {
161- typename CascadeAccessor::output_type tmp;
162164 output_type neighbourhood[9 ];
163- neighbourhood[0 ] = acc.template get<float , 2 >(coord + int16_t2 (-1 , -1 ), cascadeIndex);
164- neighbourhood[1 ] = acc.template get<float , 2 >(coord + int16_t2 (0 , -1 ), cascadeIndex);
165- neighbourhood[2 ] = acc.template get<float , 2 >(coord + int16_t2 (1 , -1 ), cascadeIndex);
166- neighbourhood[3 ] = acc.template get<float , 2 >(coord + int16_t2 (-1 , 0 ), cascadeIndex);
167- neighbourhood[4 ] = acc.template get<float , 2 >(coord + int16_t2 (0 , 0 ), cascadeIndex);
168- neighbourhood[5 ] = acc.template get<float , 2 >(coord + int16_t2 (1 , 0 ), cascadeIndex);
169- neighbourhood[6 ] = acc.template get<float , 2 >(coord + int16_t2 (-1 , 1 ), cascadeIndex);
170- neighbourhood[7 ] = acc.template get<float , 2 >(coord + int16_t2 (0 , 1 ), cascadeIndex);
171- neighbourhood[8 ] = acc.template get<float , 2 >(coord + int16_t2 (1 , 1 ), cascadeIndex);
165+ neighbourhood[0 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (-1 , -1 ), cascadeIndex).xyz ;
166+ neighbourhood[1 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (0 , -1 ), cascadeIndex).xyz ;
167+ neighbourhood[2 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (1 , -1 ), cascadeIndex).xyz ;
168+ neighbourhood[3 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (-1 , 0 ), cascadeIndex).xyz ;
169+ neighbourhood[4 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (0 , 0 ), cascadeIndex).xyz ;
170+ neighbourhood[5 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (1 , 0 ), cascadeIndex).xyz ;
171+ neighbourhood[6 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (-1 , 1 ), cascadeIndex).xyz ;
172+ neighbourhood[7 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (0 , 1 ), cascadeIndex).xyz ;
173+ neighbourhood[8 ] = acc.template get<scalar_t , 2 >(coord + int16_t2 (1 , 1 ), cascadeIndex).xyz ;
172174
173175 // numerical robustness
174176 float32_t3 excl_hood_sum = ((neighbourhood[0 ] + neighbourhood[1 ]) + (neighbourhood[2 ] + neighbourhood[3 ])) +
0 commit comments