Skip to content

Commit eb6a3e6

Browse files
authored
Protect from_hepevt against invalid parents/children record (#47)
This protects from_hepevt from invalid parent/children ranges. Previously the code would abort or even crash on such input. Now it raises a Python RuntimeError.
1 parent 968be50 commit eb6a3e6

File tree

2 files changed

+77
-14
lines changed

2 files changed

+77
-14
lines changed

src/from_hepevt.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@
88
#include <utility>
99
#include <vector>
1010

11-
// template <>
12-
// struct std::hash<std::pair<int, int>> {
13-
// std::size_t operator()(const std::pair<int, int>& p) const noexcept {
14-
// auto h1 = std::hash<int>{}(p.first);
15-
// auto h2 = std::hash<int>{}(p.second);
16-
// return h1 ^ (h2 << 1); // or use boost::hash_combine
17-
// }
18-
// };
19-
2011
template <>
2112
struct std::less<std::pair<int, int>> {
2213
bool operator()(const std::pair<int, int>& a,
@@ -35,9 +26,10 @@ void normalize(int& m1, int& m2) {
3526
// m1 < m2, both > 0: interaction
3627
// m2 < m1, both > 0: same, needs swapping
3728

38-
if (m1 > 0 && m2 > 0 && m2 < m1) std::swap(m1, m2);
39-
40-
if (m1 > 0 && m2 == 0) m2 = m1;
29+
if (m1 > 0 && m2 == 0)
30+
m2 = m1;
31+
else if (m2 < m1)
32+
std::swap(m1, m2);
4133

4234
--m1; // fortran to c index
4335
}
@@ -102,14 +94,28 @@ void connect_parents_and_children(GenEvent& event, bool parents,
10294

10395
int m1 = vi.first.first;
10496
int m2 = vi.first.second;
97+
10598
// there must be at least one parent or child when we arrive here...
10699
normalize(m1, m2);
107-
assert(m1 < m2);
100+
assert(m1 < m2); // postcondition after normalize
101+
102+
if (m1 < 0 || m2 > n) {
103+
std::ostringstream os;
104+
os << "invalid " << (parents ? "parents" : "children") << " range for vertex "
105+
<< event.vertices().size() << " [" << m1 << ", " << m2
106+
<< ") total number of particles " << n;
107+
throw std::runtime_error(os.str().c_str());
108+
}
108109

109110
// ...with at least one child or parent
110111
const auto& co = vi.second;
111-
assert(!co.empty());
112112

113+
if (co.empty()) {
114+
std::ostringstream os;
115+
os << "invalid empty " << (!parents ? "parents" : "children")
116+
<< " list for vertex " << event.vertices().size();
117+
throw std::runtime_error(os.str().c_str());
118+
}
113119
FourVector pos;
114120
if (has_vertex) {
115121
// we assume this is a production vertex

tests/test_from_hepevt.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pyhepmc as hep
2+
import numpy as np
3+
import pytest
4+
5+
6+
def test_no_vertex_info():
7+
px = py = pz = en = m = np.linspace(0, 1, 4)
8+
9+
pid = np.arange(4) + 1
10+
sta = np.zeros(4, dtype=np.int32)
11+
parents = [(0, 0), (1, 1), (0, 0), (0, 0)]
12+
hev = hep.GenEvent()
13+
hev.from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
14+
assert len(hev.vertices) == 1
15+
assert len(hev.particles) == 4
16+
17+
18+
def test_parents_range_exceeding_particle_range():
19+
px = py = pz = en = m = np.linspace(0, 1, 6)
20+
pid = np.arange(6) + 1
21+
sta = np.zeros(6, dtype=np.int32)
22+
parents = [(0, 0), (1, 1), (2, 0), (3, 5), (4, 10), (3, 5)]
23+
with pytest.raises(RuntimeError):
24+
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
25+
26+
27+
def test_invalid_length_of_parents():
28+
px = py = pz = en = m = np.linspace(0, 1, 3)
29+
pid = np.arange(3) + 1
30+
sta = np.zeros(3, dtype=np.int32)
31+
parents = [(0, 0), (1, 2)]
32+
with pytest.raises(RuntimeError):
33+
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
34+
35+
36+
def test_inverted_parents_range():
37+
px = py = pz = en = m = vx = vy = vz = vt = np.linspace(0, 1, 4)
38+
pid = np.arange(4) + 1
39+
sta = np.zeros(4, dtype=np.int32)
40+
# inverted range is not an error (2, 1) will be converted to (1, 2)
41+
parents = [(0, 0), (2, 1), (3, 3), (3, 3)]
42+
hev = hep.GenEvent()
43+
hev.from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
44+
expected = [[0, 1], [2]]
45+
got = [[p.id - 1 for p in v.particles_in] for v in hev.vertices]
46+
assert expected == got
47+
48+
49+
@pytest.mark.parametrize("bad", ([-4, 1], [1, -4]))
50+
def test_negative_parents_range(bad):
51+
px = py = pz = en = m = vx = vy = vz = vt = np.linspace(0, 1, 4)
52+
pid = np.arange(4) + 1
53+
sta = np.zeros(4, dtype=np.int32)
54+
# inverted range is not an error (2, 1) will be converted to (1, 2)
55+
parents = [(0, 0), bad, (3, 3), (3, 3)]
56+
with pytest.raises(RuntimeError):
57+
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)

0 commit comments

Comments
 (0)