Skip to content

Commit b3f6d55

Browse files
committed
Fixed ResolveAccessor concept
1 parent a40a025 commit b3f6d55

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

examples_tests

include/nbl/builtin/hlsl/rwmc/resolve.hlsl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
66
#include <nbl/builtin/hlsl/rwmc/ResolveParameters.hlsl>
77
#include <nbl/builtin/hlsl/concepts/accessors/loadable_image.hlsl>
8+
#include <nbl/builtin/hlsl/colorspace.hlsl>
9+
#include <nbl/builtin/hlsl/vector_utils/vector_traits.hlsl>
810

911
namespace nbl
1012
{
@@ -19,23 +21,21 @@ namespace rwmc
1921
// not the greatest syntax but works
2022
#define NBL_CONCEPT_PARAM_0 (a,T)
2123
#define NBL_CONCEPT_PARAM_1 (scalar,VectorScalarType)
22-
#define NBL_CONCEPT_PARAM_2 (vec,vector<VectorScalarType, Dims>)
2324
// start concept
2425
NBL_CONCEPT_BEGIN(2)
2526
// need to be defined AFTER the concept begins
2627
#define a NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
2728
#define scalar NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
28-
#define vec NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
2929
NBL_CONCEPT_END(
30-
((NBL_CONCEPT_REQ_EXPR)((a.calcLuma(vec))))
30+
((NBL_CONCEPT_REQ_EXPR)((a.calcLuma(vector<VectorScalarType, 3>(scalar, scalar, scalar)))))
3131
);
3232
#undef a
33-
#undef vec
33+
#undef scalar
3434
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
3535

3636
/* ResolveAccessor is required to:
3737
* - satisfy `LoadableImage` concept requirements
38-
* - implement function called `calcLuma` which calculates luma from a pixel value
38+
* - implement function called `calcLuma` which calculates luma from a 3 component pixel value
3939
*/
4040

4141
template<typename T, typename VectorScalarType, int32_t Dims>
@@ -50,9 +50,9 @@ struct ResolveAccessorAdaptor
5050

5151
RWTexture2DArray<float32_t4> cascade;
5252

53-
float32_t calcLuma(in float32_t3 col)
53+
float32_t calcLuma(NBL_REF_ARG(float32_t3) col)
5454
{
55-
return hlsl::dot<float32_t3>(hlsl::transpose(colorspace::scRGBtoXYZ)[1], col);
55+
return hlsl::dot<float32_t3>(colorspace::scRGB::ToXYZ()[1], col);
5656
}
5757

5858
template<typename OutputScalarType, int32_t Dimension>
@@ -69,10 +69,11 @@ struct ResolveAccessorAdaptor
6969
}
7070
};
7171

72-
template<typename CascadeAccessor, typename OutputColorType> //NBL_PRIMARY_REQUIRES(ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
72+
template<typename CascadeAccessor, typename OutputColorTypeVec NBL_PRIMARY_REQUIRES(concepts::Vector<OutputColorTypeVec> && ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
7373
struct Resolver
7474
{
75-
using output_type = OutputColorType;
75+
using output_type = OutputColorTypeVec;
76+
using scalar_t = typename vector_traits<output_type>::scalar_type;
7677

7778
struct CascadeSample
7879
{
@@ -91,13 +92,15 @@ struct Resolver
9192

9293
output_type operator()(NBL_REF_ARG(CascadeAccessor) acc, const int16_t2 coord)
9394
{
94-
float reciprocalBaseI = 1.f;
95+
using scalar_t = typename vector_traits<output_type>::scalar_type;
96+
97+
scalar_t reciprocalBaseI = 1.f;
9598
CascadeSample curr = __sampleCascade(acc, coord, 0u, reciprocalBaseI);
9699

97-
float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);
98-
float Emin = params.initialEmin;
100+
output_type accumulation = output_type(0.0f, 0.0f, 0.0f);
101+
scalar_t Emin = params.initialEmin;
99102

100-
float prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
103+
scalar_t prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
101104
for (int16_t i = 0u; i <= params.lastCascadeIndex; i++)
102105
{
103106
const bool notFirstCascade = i != 0;
@@ -110,13 +113,13 @@ struct Resolver
110113
next = __sampleCascade(acc, coord, int16_t(i + 1), reciprocalBaseI);
111114
}
112115

113-
float reliability = 1.f;
116+
scalar_t reliability = 1.f;
114117
// sample counting-based reliability estimation
115118
if (params.reciprocalKappa <= 1.f)
116119
{
117-
float localReliability = curr.normalizedCenterLuma;
120+
scalar_t localReliability = curr.normalizedCenterLuma;
118121
// reliability in 3x3 pixel block (see robustness)
119-
float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
122+
scalar_t globalReliability = curr.normalizedNeighbourhoodAverageLuma;
120123
if (notFirstCascade)
121124
{
122125
localReliability += prevNormalizedCenterLuma;
@@ -130,11 +133,11 @@ struct Resolver
130133
// check if above minimum sampling threshold (avg 9 sample occurences in 3x3 neighbourhood), then use per-pixel reliability (NOTE: tertiary op is in reverse)
131134
reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
132135
{
133-
const float accumLuma = acc.calcLuma(accumulation);
136+
const scalar_t accumLuma = acc.calcLuma(accumulation);
134137
if (accumLuma > Emin)
135138
Emin = accumLuma;
136139

137-
const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
140+
const scalar_t colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
138141

139142
reliability += colorReliability;
140143
reliability *= params.NOverKappa;
@@ -156,19 +159,18 @@ struct Resolver
156159

157160
// pseudo private stuff:
158161

159-
CascadeSample __sampleCascade(NBL_REF_ARG(CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, float reciprocalBaseI)
162+
CascadeSample __sampleCascade(NBL_REF_ARG(CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, scalar_t reciprocalBaseI)
160163
{
161-
typename CascadeAccessor::output_type tmp;
162164
output_type neighbourhood[9];
163-
neighbourhood[0] = acc.template get<float, 2>(coord + int16_t2(-1, -1), cascadeIndex);
164-
neighbourhood[1] = acc.template get<float, 2>(coord + int16_t2(0, -1), cascadeIndex);
165-
neighbourhood[2] = acc.template get<float, 2>(coord + int16_t2(1, -1), cascadeIndex);
166-
neighbourhood[3] = acc.template get<float, 2>(coord + int16_t2(-1, 0), cascadeIndex);
167-
neighbourhood[4] = acc.template get<float, 2>(coord + int16_t2(0, 0), cascadeIndex);
168-
neighbourhood[5] = acc.template get<float, 2>(coord + int16_t2(1, 0), cascadeIndex);
169-
neighbourhood[6] = acc.template get<float, 2>(coord + int16_t2(-1, 1), cascadeIndex);
170-
neighbourhood[7] = acc.template get<float, 2>(coord + int16_t2(0, 1), cascadeIndex);
171-
neighbourhood[8] = acc.template get<float, 2>(coord + int16_t2(1, 1), cascadeIndex);
165+
neighbourhood[0] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, -1), cascadeIndex).xyz;
166+
neighbourhood[1] = acc.template get<scalar_t, 2>(coord + int16_t2(0, -1), cascadeIndex).xyz;
167+
neighbourhood[2] = acc.template get<scalar_t, 2>(coord + int16_t2(1, -1), cascadeIndex).xyz;
168+
neighbourhood[3] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, 0), cascadeIndex).xyz;
169+
neighbourhood[4] = acc.template get<scalar_t, 2>(coord + int16_t2(0, 0), cascadeIndex).xyz;
170+
neighbourhood[5] = acc.template get<scalar_t, 2>(coord + int16_t2(1, 0), cascadeIndex).xyz;
171+
neighbourhood[6] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, 1), cascadeIndex).xyz;
172+
neighbourhood[7] = acc.template get<scalar_t, 2>(coord + int16_t2(0, 1), cascadeIndex).xyz;
173+
neighbourhood[8] = acc.template get<scalar_t, 2>(coord + int16_t2(1, 1), cascadeIndex).xyz;
172174

173175
// numerical robustness
174176
float32_t3 excl_hood_sum = ((neighbourhood[0] + neighbourhood[1]) + (neighbourhood[2] + neighbourhood[3])) +

0 commit comments

Comments
 (0)