Skip to content

Commit e60a9d9

Browse files
Superharzhenryiii
authored andcommitted
Split multi_weight_value into multi_weight_value and multi_weight_reference
1 parent 8800399 commit e60a9d9

File tree

2 files changed

+91
-62
lines changed

2 files changed

+91
-62
lines changed

include/bh_python/multi_weight.hpp

Lines changed: 91 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
namespace boost {
1313
namespace histogram {
1414

15-
template <class T>
16-
struct multi_weight_value : public boost::span<T> {
17-
using boost::span<T>::span;
1815

19-
void operator()(const boost::span<T> values) { operator+=(values); }
16+
template <class T, class BASE>
17+
struct multi_weight_base : public BASE {
18+
using BASE::BASE;
2019

2120
template <class S>
2221
bool operator==(const S values) const {
@@ -31,7 +30,30 @@ struct multi_weight_value : public boost::span<T> {
3130
return !operator==(values);
3231
}
3332

34-
void operator+=(const std::vector<T> values) { operator+=(boost::span<T>(values)); }
33+
34+
};
35+
36+
template <class T>
37+
struct multi_weight_reference : public multi_weight_base<T, boost::span<T>> {
38+
//using boost::span<T>::span;
39+
using multi_weight_base::multi_weight_base;
40+
41+
void operator()(const boost::span<T> values) { operator+=(values); }
42+
43+
//template <class S>
44+
//bool operator==(const S values) const {
45+
// if(values.size() != this->size())
46+
// return false;
47+
//
48+
// return std::equal(this->begin(), this->end(), values.begin());
49+
//}
50+
//
51+
//template <class S>
52+
//bool operator!=(const S values) const {
53+
// return !operator==(values);
54+
//}
55+
56+
//void operator+=(const std::vector<T> values) { operator+=(boost::span<T>(values)); }
3557

3658
void operator+=(const boost::span<T> values) {
3759
// template <class S>
@@ -53,13 +75,60 @@ struct multi_weight_value : public boost::span<T> {
5375
}
5476
};
5577

78+
template <class T>
79+
struct multi_weight_value : public multi_weight_base<T, std::vector<T>> {
80+
using multi_weight_base::multi_weight_base;
81+
82+
multi_weight_value(const boost::span<T> values) {
83+
this->assign(values.begin(), values.end());
84+
}
85+
multi_weight_value() = default;
86+
87+
void operator()(const boost::span<T> values) { operator+=(values); }
88+
89+
//template <class S>
90+
//bool operator==(const S values) const {
91+
// if(values.size() != this->size())
92+
// return false;
93+
//
94+
// return std::equal(this->begin(), this->end(), values.begin());
95+
//}
96+
//
97+
//template <class S>
98+
//bool operator!=(const S values) const {
99+
// return !operator==(values);
100+
//}
101+
//
102+
//void operator+=(const std::vector<T> values) { operator+=(boost::span<T>(values)); }
103+
104+
//template <class S>
105+
//void operator+=(const S values) {
106+
void operator+=(const boost::span<T> values) {
107+
if(values.size() != this->size()) {
108+
if (this->size() > 0) {
109+
throw std::runtime_error("size does not match");
110+
}
111+
this->assign(values.begin(), values.end());
112+
return;
113+
}
114+
auto it = this->begin();
115+
for(const T& x : values)
116+
*it++ += x;
117+
}
118+
119+
template <class S>
120+
void operator=(const S values) {
121+
this->assign(values.begin(), values.end());
122+
}
123+
};
124+
56125
template <class ElementType = double>
57126
class multi_weight {
58127
public:
59128
using element_type = ElementType;
60129
using value_type = multi_weight_value<element_type>;
61-
using reference = value_type;
62-
using const_reference = const value_type;
130+
using reference = multi_weight_reference<element_type>;
131+
using const_reference = const reference;
63132

64133
template <class Value, class Reference, class MWPtr>
65134
struct iterator_base
@@ -194,68 +263,32 @@ std::ostream& operator<<(std::ostream& os, const multi_weight_value<T>& v) {
194263
return os;
195264
}
196265

266+
template <class T>
267+
std::ostream& operator<<(std::ostream& os, const multi_weight_reference<T>& v) {
268+
os << "multi_weight_reference(";
269+
bool first = true;
270+
for(const T& x : v)
271+
if(first) {
272+
first = false;
273+
os << x;
274+
} else
275+
os << ", " << x;
276+
os << ")";
277+
return os;
278+
}
279+
197280
template <class T>
198281
std::ostream& operator<<(std::ostream& os, const multi_weight<T>& v) {
199282
os << "multi_weight(\n";
200283
int index = 0;
201-
for(const multi_weight_value<T>& x : v) {
284+
for(const multi_weight_reference<T>& x : v) {
202285
os << "Index " << index << ": " << x << "\n";
203286
index++;
204287
}
205288
os << ")";
206289
return os;
207290
}
208291

209-
namespace algorithm {
210-
211-
/** Compute the sum over all histogram cells (underflow/overflow included by default).
212-
213-
The implementation favors accuracy and protection against overflow over speed. If the
214-
value type of the histogram is an integral or floating point type,
215-
accumulators::sum<double> is used to compute the sum, else the original value type is
216-
used. Compilation fails, if the value type does not support operator+=. The return
217-
type is double if the value type of the histogram is integral or floating point, and
218-
the original value type otherwise.
219-
220-
If you need a different trade-off, you can write your own loop or use
221-
`std::accumulate`:
222-
```
223-
// iterate over all bins
224-
auto sum_all = std::accumulate(hist.begin(), hist.end(), 0.0);
225-
226-
// skip underflow/overflow bins
227-
double sum = 0;
228-
for (auto&& x : indexed(hist))
229-
sum += *x; // dereference accessor
230-
231-
// or:
232-
// auto ind = boost::histogram::indexed(hist);
233-
// auto sum = std::accumulate(ind.begin(), ind.end(), 0.0);
234-
```
235-
236-
@returns accumulator type or double
237-
238-
@param hist Const reference to the histogram.
239-
@param cov Iterate over all or only inner bins (optional, default: all).
240-
*/
241-
template <class A, class B>
242-
std::vector<B> sum(const histogram<A, multi_weight<B>>& hist,
243-
const coverage cov = coverage::all) {
244-
using sum_type = typename histogram<A, multi_weight<B>>::value_type;
245-
// T is arithmetic, compute sum accurately with high dynamic range
246-
std::vector<B> v(unsafe_access::storage(hist).nelem_, 0.);
247-
sum_type sum(v);
248-
if(cov == coverage::all)
249-
for(auto&& x : hist)
250-
sum += x;
251-
else
252-
// sum += x also works if sum_type::operator+=(const sum_type&) exists
253-
for(auto&& x : indexed(hist))
254-
sum += *x;
255-
return v;
256-
}
257-
258-
} // namespace algorithm
259292
} // namespace histogram
260293
} // namespace boost
261294

include/bh_python/register_histogram.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,6 @@ auto register_histogram<bh::multi_weight<double>>(py::module& m,
343343
[](const histogram_t& self, py::args& args) -> value_type {
344344
auto int_args = py::cast<std::vector<int>>(args);
345345
auto at_value = self.at(int_args);
346-
// value_type return_obj;
347-
// return_obj.insert(return_obj.end(), at_value.begin(),
348-
// at_value.end()); return_obj.assign(at_value.begin(),
349-
// at_value.end()); return return_obj;
350346
return value_type(at_value.begin(), at_value.end());
351347
})
352348

0 commit comments

Comments
 (0)