Skip to content

Commit 9ded23e

Browse files
authored
Intersection configuration (#1082)
Move the mask tolerances into dedicated config in order to clean up interfaces. Also moves intersection kernel file into the intersection folder
1 parent 48b7330 commit 9ded23e

37 files changed

+326
-306
lines changed

core/include/detray/navigation/detail/print_state.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ DETRAY_HOST inline std::string print_candidates(const state_type &state,
121121
constexpr int cw{20};
122122

123123
debug_stream << std::left << std::setw(cw) << "Overstep tol.:"
124-
<< cfg.overstep_tolerance / detray::unit<scalar_t>::um << " um"
125-
<< std::endl;
124+
<< cfg.intersection.overstep_tolerance /
125+
detray::unit<scalar_t>::um
126+
<< " um" << std::endl;
126127

127128
debug_stream << std::setw(cw) << "Track:"
128129
<< "pos: [r = " << vector::perp(track_pos)

core/include/detray/navigation/direct_navigator.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include "detray/definitions/units.hpp"
1818
#include "detray/geometry/barcode.hpp"
1919
#include "detray/navigation/intersection/intersection.hpp"
20+
#include "detray/navigation/intersection/intersection_kernel.hpp"
2021
#include "detray/navigation/intersection/ray_intersector.hpp"
21-
#include "detray/navigation/intersection_kernel.hpp"
2222
#include "detray/navigation/navigation_config.hpp"
2323
#include "detray/navigation/navigation_state.hpp"
2424
#include "detray/navigation/navigator.hpp"
@@ -196,7 +196,7 @@ class direct_navigator {
196196
}
197197

198198
assert(!navigation.get_target_barcode().is_invalid());
199-
update_intersection(track, navigation, cfg, ctx);
199+
update_intersection(track, navigation, cfg.intersection, ctx);
200200

201201
if (is_before_actor_run) {
202202
if (navigation.has_reached_candidate(navigation.target(), cfg)) {
@@ -209,7 +209,8 @@ class direct_navigator {
209209
cfg));
210210

211211
if (!navigation.no_next_external()) {
212-
update_intersection(track, navigation, cfg, ctx);
212+
update_intersection(track, navigation, cfg.intersection,
213+
ctx);
213214
}
214215

215216
DETRAY_VERBOSE_HOST_DEVICE("Update complete: On surface");
@@ -234,7 +235,8 @@ class direct_navigator {
234235
private:
235236
template <typename track_t>
236237
DETRAY_HOST_DEVICE inline void update_intersection(
237-
const track_t &track, state &navigation, const navigation::config &cfg,
238+
const track_t &track, state &navigation,
239+
const intersection::config &intr_cfg,
238240
const context_type &ctx = {}) const {
239241

240242
if (navigation.target().sf_desc.barcode().is_invalid()) {
@@ -250,11 +252,8 @@ class direct_navigator {
250252
track.pos(),
251253
static_cast<scalar_type>(navigation.direction()) *
252254
track.dir()),
253-
navigation.target(), det.transform_store(), ctx,
254-
cfg.template mask_tolerance<scalar_type>(),
255-
static_cast<scalar_type>(cfg.mask_tolerance_scalor),
256-
scalar_type{0.f},
257-
static_cast<scalar_type>(cfg.overstep_tolerance));
255+
navigation.target(), det.transform_store(), ctx, intr_cfg,
256+
scalar_type{0.f});
258257

259258
// If an intersection is not found, proceed the track with safe step
260259
// size
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/** Detray library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s)
11+
#include "detray/definitions/detail/qualifiers.hpp"
12+
#include "detray/definitions/units.hpp"
13+
14+
// System include(s)
15+
#include <ostream>
16+
17+
namespace detray::intersection {
18+
19+
/// Intersector configuration
20+
struct config {
21+
/// Tolerance on the mask 'is_inside' check:
22+
/// @{
23+
/// Minimal tolerance: ~ position uncertainty on surface
24+
float min_mask_tolerance{1e-5f * unit<float>::mm};
25+
/// Maximal tolerance: loose tolerance when still far away from surface
26+
float max_mask_tolerance{3.f * unit<float>::mm};
27+
/// Scale factor on the path used for the mask tolerance calculation
28+
float mask_tolerance_scalor{5e-2f};
29+
/// @}
30+
/// Maximal absolute path distance for a track to be considered 'on surface'
31+
float path_tolerance{1.f * unit<float>::um};
32+
/// How far behind the track position to look for candidates
33+
float overstep_tolerance{-1000.f * unit<float>::um};
34+
35+
/// Print the intersector configuration
36+
DETRAY_HOST
37+
friend std::ostream& operator<<(std::ostream& out, const config& cfg) {
38+
out << " Min. mask tolerance : "
39+
<< cfg.min_mask_tolerance / detray::unit<float>::mm << " [mm]\n"
40+
<< " Max. mask tolerance : "
41+
<< cfg.max_mask_tolerance / detray::unit<float>::mm << " [mm]\n"
42+
<< " Mask tolerance scalor : " << cfg.mask_tolerance_scalor << "\n"
43+
<< " Path tolerance : "
44+
<< cfg.path_tolerance / detray::unit<float>::um << " [um]\n"
45+
<< " Overstep tolerance : "
46+
<< cfg.overstep_tolerance / detray::unit<float>::um << " [um]\n";
47+
48+
return out;
49+
}
50+
};
51+
52+
} // namespace detray::intersection

core/include/detray/navigation/intersection_kernel.hpp renamed to core/include/detray/navigation/intersection/intersection_kernel.hpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "detray/definitions/units.hpp"
1515
#include "detray/geometry/concepts.hpp"
1616
#include "detray/navigation/intersection/intersection.hpp"
17+
#include "detray/navigation/intersection/intersection_config.hpp"
1718
#include "detray/tracks/ray.hpp"
1819
#include "detray/utils/ranges.hpp"
1920

@@ -50,11 +51,8 @@ struct intersection_initialize {
5051
const surface_t &sf_desc,
5152
const transform_container_t &contextual_transforms,
5253
const typename transform_container_t::context_type &ctx,
53-
const darray<scalar_t, 2u> &mask_tolerance = {0.f,
54-
1.f * unit<scalar_t>::mm},
55-
const scalar_t mask_tol_scalor = 0.f,
56-
const scalar_t external_mask_tolerance = 0.f,
57-
const scalar_t overstep_tol = 0.f) const {
54+
const intersection::config &cfg,
55+
const scalar_t external_mask_tolerance = 0.f) const {
5856

5957
using mask_t = typename mask_group_t::value_type;
6058
using shape_t = typename mask_t::shape;
@@ -82,9 +80,10 @@ struct intersection_initialize {
8280
assert(mask_idx < mask_group.size());
8381

8482
result = intersector.point_of_intersection(
85-
traj, ctf, mask_group[mask_idx], overstep_tol);
83+
traj, ctf, mask_group[mask_idx], cfg.overstep_tolerance);
8684
} else {
87-
result = intersector.point_of_intersection(traj, ctf, overstep_tol);
85+
result = intersector.point_of_intersection(traj, ctf,
86+
cfg.overstep_tolerance);
8887
}
8988

9089
// Check if any valid solutions were found
@@ -115,9 +114,8 @@ struct intersection_initialize {
115114
std::uint8_t n_found{0u};
116115

117116
for (std::size_t i = 0u; i < n_sol; ++i) {
118-
resolve_mask(is, traj, result[i], sf_desc, mask, ctf,
119-
mask_tolerance, mask_tol_scalor,
120-
external_mask_tolerance, overstep_tol);
117+
resolve_mask(is, traj, result[i], sf_desc, mask, ctf, cfg,
118+
external_mask_tolerance);
121119

122120
if (is.is_probably_inside()) {
123121
insert_sorted(is, is_container);
@@ -128,9 +126,8 @@ struct intersection_initialize {
128126
}
129127
}
130128
} else {
131-
resolve_mask(is, traj, result, sf_desc, mask, ctf,
132-
mask_tolerance, mask_tol_scalor,
133-
external_mask_tolerance, overstep_tol);
129+
resolve_mask(is, traj, result, sf_desc, mask, ctf, cfg,
130+
external_mask_tolerance);
134131

135132
if (is.is_probably_inside()) {
136133
insert_sorted(is, is_container);
@@ -181,11 +178,8 @@ struct intersection_update {
181178
const traj_t &traj, intersection_t &sfi,
182179
const transform_container_t &contextual_transforms,
183180
const typename transform_container_t::context_type &ctx,
184-
const darray<scalar_t, 2u> &mask_tolerance = {0.f,
185-
1.f * unit<scalar_t>::mm},
186-
const scalar_t mask_tol_scalor = 0.f,
187-
const scalar_t external_mask_tolerance = 0.f,
188-
const scalar_t overstep_tol = 0.f) const {
181+
const intersection::config &cfg,
182+
const scalar_t external_mask_tolerance = 0.f) const {
189183

190184
using mask_t = typename mask_group_t::value_type;
191185
using shape_t = typename mask_t::shape;
@@ -212,9 +206,10 @@ struct intersection_update {
212206
assert(mask_idx < mask_group.size());
213207

214208
result = intersector.point_of_intersection(
215-
traj, ctf, mask_group[mask_idx], overstep_tol);
209+
traj, ctf, mask_group[mask_idx], cfg.overstep_tolerance);
216210
} else {
217-
result = intersector.point_of_intersection(traj, ctf, overstep_tol);
211+
result = intersector.point_of_intersection(traj, ctf,
212+
cfg.overstep_tolerance);
218213
}
219214

220215
// Check if any valid solutions were found
@@ -240,13 +235,11 @@ struct intersection_update {
240235

241236
// Build the resulting intersecion(s) from the intersection point
242237
if constexpr (n_sol > 1) {
243-
resolve_mask(sfi, traj, result[0], sfi.sf_desc, mask, ctf,
244-
mask_tolerance, mask_tol_scalor,
245-
external_mask_tolerance, overstep_tol);
238+
resolve_mask(sfi, traj, result[0], sfi.sf_desc, mask, ctf, cfg,
239+
external_mask_tolerance);
246240
} else {
247-
resolve_mask(sfi, traj, result, sfi.sf_desc, mask, ctf,
248-
mask_tolerance, mask_tol_scalor,
249-
external_mask_tolerance, overstep_tol);
241+
resolve_mask(sfi, traj, result, sfi.sf_desc, mask, ctf, cfg,
242+
external_mask_tolerance);
250243
}
251244

252245
if (sfi.is_probably_inside()) {

core/include/detray/navigation/intersection/intersector_base.hpp

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "detray/geometry/concepts.hpp"
1717
#include "detray/geometry/detail/shape_utils.hpp"
1818
#include "detray/navigation/intersection/intersection.hpp"
19+
#include "detray/navigation/intersection/intersection_config.hpp"
1920
#include "detray/utils/invalid_values.hpp"
2021

2122
// System include(s)
@@ -76,18 +77,14 @@ struct intersector_base : public intersector_t {
7677
DETRAY_HOST_DEVICE constexpr intersection_type<surface_descr_t> operator()(
7778
const trajectory_type<other_algebra_t> &traj, const surface_descr_t &sf,
7879
const mask_t &mask, const transform3_type &trf,
79-
const darray<scalar_type, 2u> mask_tolerance =
80-
{0.f, 1.f * unit<value_type>::mm},
81-
const scalar_type mask_tol_scalor = 0.f,
82-
const scalar_type external_mask_tol = 0.f,
83-
const scalar_type overstep_tol = 0.f) const {
80+
const intersection::config &cfg = {},
81+
const scalar_type external_mask_tol = 0.f) const {
8482

85-
result_type result = call_intersector(traj, mask, trf, overstep_tol);
83+
result_type result =
84+
call_intersector(traj, mask, trf, cfg.overstep_tolerance);
8685

8786
intersection_type<surface_descr_t> is;
88-
89-
resolve_mask(is, traj, result, sf, mask, trf, mask_tolerance,
90-
mask_tol_scalor, external_mask_tol, overstep_tol);
87+
resolve_mask(is, traj, result, sf, mask, trf, cfg, external_mask_tol);
9188

9289
return is;
9390
}
@@ -100,23 +97,19 @@ struct intersector_base : public intersector_t {
10097
n_solutions>
10198
operator()(const trajectory_type<other_algebra_t> &traj,
10299
const surface_descr_t &sf, const mask_t &mask,
103-
const transform3_type &trf,
104-
const darray<scalar_type, 2u> mask_tolerance =
105-
{0.f, 100.f * unit<value_type>::um},
106-
const scalar_type mask_tol_scalor = 0.f,
107-
const scalar_type external_mask_tol = 0.f,
108-
const scalar_type overstep_tol = 0.f) const {
100+
const transform3_type &trf, const intersection::config &cfg = {},
101+
const scalar_type external_mask_tol = 0.f) const {
109102

110103
// One or both of these solutions might be invalid
111-
result_type result = call_intersector(traj, mask, trf, overstep_tol);
104+
result_type result =
105+
call_intersector(traj, mask, trf, cfg.overstep_tolerance);
112106

113107
darray<intersection_type<surface_descr_t>, n_solutions> ret;
114108

115109
for (std::size_t i = 0u; i < n_solutions; ++i) {
116110
if (detray::detail::any_of(result[i].is_valid())) {
117-
resolve_mask(ret[i], traj, result[i], sf, mask, trf,
118-
mask_tolerance, mask_tol_scalor, external_mask_tol,
119-
overstep_tol);
111+
resolve_mask(ret[i], traj, result[i], sf, mask, trf, cfg,
112+
external_mask_tol);
120113
}
121114
}
122115

@@ -135,9 +128,13 @@ struct intersector_base : public intersector_t {
135128
const scalar_type mask_tolerance,
136129
const scalar_type overstep_tol = 0.f) const {
137130

138-
return this->operator()(traj, sf, mask, trf,
139-
{mask_tolerance, mask_tolerance}, 0.f, 0.f,
140-
overstep_tol);
131+
const intersection::config intr_cfg{
132+
.min_mask_tolerance = static_cast<float>(mask_tolerance),
133+
.max_mask_tolerance = static_cast<float>(mask_tolerance),
134+
.mask_tolerance_scalor = 0.f,
135+
.overstep_tolerance = static_cast<float>(overstep_tol)};
136+
137+
return this->operator()(traj, sf, mask, trf, intr_cfg, 0.f);
141138
}
142139

143140
private:
@@ -176,23 +173,21 @@ DETRAY_HOST_DEVICE constexpr void resolve_mask(
176173
const intersection_point<algebra_t, point_t, intersection::contains_pos>
177174
&ip,
178175
const surface_descr_t sf_desc, const mask_t &mask, const transform3_t &trf,
179-
const darray<scalar_t, 2> &mask_tolerance,
180-
const scalar_t mask_tol_scalor = 0.f,
181-
const scalar_t external_mask_tolerance = 0.f,
182-
const scalar_t overstep_tol = 0.f) {
176+
const intersection::config &cfg = {},
177+
const scalar_t external_mask_tolerance = 0.f) {
183178

184179
// Mask out solutions that don't meet the overstepping tolerance (SoA)
185180
if constexpr (concepts::soa<algebra_t>) {
186181
using status_t = typename intersection_t::status_t;
187182

188-
is.status(is.path() < overstep_tol) =
183+
is.status(is.path() < cfg.overstep_tolerance) =
189184
static_cast<status_t>(intersection::status::e_outside);
190185
} else {
191186
is.set_status(intersection::status::e_outside);
192187
}
193188

194189
// Build intersection struct from test trajectory, if the distance is valid
195-
if (detray::detail::none_of(ip.path >= overstep_tol)) {
190+
if (detray::detail::none_of(ip.path >= cfg.overstep_tolerance)) {
196191
// Not a valid intersection
197192
return;
198193
}
@@ -219,21 +214,28 @@ DETRAY_HOST_DEVICE constexpr void resolve_mask(
219214
mask_t::to_local_frame3D(trf, glob_pos, traj.dir(ip.path)));
220215
}
221216

217+
scalar_t base_tol = 0.f;
218+
scalar_t ext_tol = 0.f;
219+
222220
// Tol.: scale with distance of surface to account for track bending
223-
const scalar_t base_tol = math::max(
224-
mask_tolerance[0],
225-
math::min(mask_tolerance[1], mask_tol_scalor * math::fabs(ip.path)));
221+
if (!sf_desc.is_portal()) {
222+
ext_tol = external_mask_tolerance;
223+
base_tol = math::max(
224+
static_cast<scalar_t>(cfg.min_mask_tolerance),
225+
math::min(static_cast<scalar_t>(cfg.max_mask_tolerance),
226+
static_cast<scalar_t>(cfg.mask_tolerance_scalor) *
227+
math::fabs(ip.path)));
228+
}
226229

227230
// Mask check results with and without external tolerance
228231
typename mask_t::result_type mask_check{};
229232

230233
// Intersector provides specialized local point
231234
if constexpr (std::same_as<point_t, dpoint2D<algebra_t>>) {
232-
mask_check = mask.resolve(ip.point, base_tol, external_mask_tolerance);
235+
mask_check = mask.resolve(ip.point, base_tol, ext_tol);
233236
} else {
234237
// Otherwise, let the shape transform the point to local
235-
mask_check =
236-
mask.resolve(trf, ip.point, base_tol, external_mask_tolerance);
238+
mask_check = mask.resolve(trf, ip.point, base_tol, ext_tol);
237239
}
238240

239241
// Set the less strict status first, then overwrite with more strict

0 commit comments

Comments
 (0)