Skip to content

Commit cfd7e3e

Browse files
committed
Limit sizes of frequency and 'missed' subtables.
1 parent bd3c0a6 commit cfd7e3e

File tree

2 files changed

+273
-28
lines changed

2 files changed

+273
-28
lines changed

include/huffman.tpp

Lines changed: 259 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Endpoints = google::protobuf::RepeatedField<google::protobuf::int64>;
2222
using Missed = google::protobuf::RepeatedField<google::protobuf::int64>;
2323
using Frequencies =
2424
google::protobuf::Map<google::protobuf::uint64, google::protobuf::uint64>;
25+
using SubtableSizes = google::protobuf::RepeatedField<google::protobuf::uint64>;
2526

2627
} // namespace
2728

@@ -176,6 +177,221 @@ void HuffmanCode<Symbol>::recursively_set_codewords(
176177
}
177178
}
178179

180+
namespace {
181+
182+
//! Maximum number of elements per frequency/missed subtable.
183+
inline constexpr std::size_t SUBTABLE_MAX_SIZE = 1 << 20;
184+
185+
//! A logical table split into one or more subtables of moderate size.
186+
//!
187+
//! The logical table can be read by chaining the subtables.
188+
template <typename Message, typename It> struct Supertable {
189+
// The beginning and size of a subtable.
190+
using Segment = std::pair<It, std::size_t>;
191+
192+
//! Constructor.
193+
//!
194+
//! Construct an 'empty' `Supertable`. The data members will be given the
195+
//! right sizes, but for the most part they will not populated. That is left
196+
//! to derived class constructors or callers.
197+
//!
198+
//!\param nelements Total number of subtable entries.
199+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
200+
//! `pb::HuffmanHeader`). This field will be written to.
201+
Supertable(const std::size_t nelements, SubtableSizes &nbytes_subtables)
202+
: nsubtables((nelements + SUBTABLE_MAX_SIZE - 1) / SUBTABLE_MAX_SIZE),
203+
subtables(nsubtables), segments(nsubtables),
204+
nbytes_subtables(nbytes_subtables) {
205+
nbytes_subtables.Resize(nsubtables, 0);
206+
207+
for (std::size_t i = 0; i + 1 < nsubtables; ++i) {
208+
segments.at(i).second = SUBTABLE_MAX_SIZE;
209+
}
210+
if (nsubtables) {
211+
// If `nelements` is an exact multiple of `SUBTABLE_MAX_SIZE` and not
212+
// zero, we need this last size to be `SUBTABLE_MAX_SIZE`, not `0`. If
213+
// `nelements` is zero, we won't be executing this statement.
214+
segments.back().second = nelements % SUBTABLE_MAX_SIZE
215+
? nelements % SUBTABLE_MAX_SIZE
216+
: SUBTABLE_MAX_SIZE;
217+
}
218+
}
219+
220+
//! Constructor.
221+
//!
222+
//! Construct a `Supertable` from a collection of parsed messages. This
223+
//! constructor leaves `segments` uninitialized. This is because `Supertable`
224+
//! doesn't know which field of `Message` is the subtable.
225+
//!
226+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
227+
//! `pb::HuffmanHeader`).
228+
//!\param window Window into buffer containing messages to be parsed.
229+
Supertable(SubtableSizes &nbytes_subtables, BufferWindow &window)
230+
: nsubtables(nbytes_subtables.size()), subtables(nsubtables),
231+
segments(nsubtables), nbytes_subtables(nbytes_subtables) {
232+
for (std::size_t i = 0; i < nsubtables; ++i) {
233+
subtables.at(i) = read_message<Message>(window, nbytes_subtables.Get(i));
234+
}
235+
}
236+
237+
//! Calculate and store the sizes in bytes of the subtables.
238+
//!
239+
//! This function should be called once the subtables are populated.
240+
void calculate_nbytes_subtables() {
241+
for (std::size_t i = 0; i < nsubtables; ++i) {
242+
nbytes_subtables.Set(i, subtables.at(i).ByteSize());
243+
}
244+
}
245+
246+
//! Calculate the total size in bytes of the subtables.
247+
//!
248+
//! This function assumes no changes have been made to the subtables since the
249+
//! last call to `calculate_nbytes_subtables`.
250+
std::size_t ByteSize() const {
251+
return std::accumulate(nbytes_subtables.begin(), nbytes_subtables.end(),
252+
static_cast<std::size_t>(0));
253+
}
254+
255+
void SerializeToArray(void *const p, const std::size_t n) const {
256+
unsigned char *const p_ = reinterpret_cast<unsigned char *>(p);
257+
std::size_t total = 0;
258+
for (std::size_t i = 0; i < nsubtables; ++i) {
259+
const Message &subtable = subtables.at(i);
260+
const google::protobuf::uint64 nbytes_subtable = nbytes_subtables.Get(i);
261+
262+
subtable.SerializeToArray(p_ + total, nbytes_subtable);
263+
total += nbytes_subtable;
264+
}
265+
if (total != n) {
266+
throw std::invalid_argument("serialization buffer size incorrect");
267+
}
268+
}
269+
270+
//! Number of subtables.
271+
std::size_t nsubtables;
272+
273+
//! Subtables.
274+
//!
275+
//! It might be better to name this member 'messages.' Elsewhere we use
276+
//! 'subtable' to refer to the fields of the messages containing the
277+
//! supertable elements. Using that vocabulary, a `pb::FrequencySubtable`
278+
//! would be a message while its `frequencies` field would be the subtable.
279+
std::vector<Message> subtables;
280+
281+
//! Segments for a concatenated subtable chain.
282+
//!
283+
//! A `Chain<std::vector<Segment>::iterator>` can be constructed from this.
284+
std::vector<Segment> segments;
285+
286+
//! Sizes in bytes of the subtables.
287+
SubtableSizes &nbytes_subtables;
288+
};
289+
290+
//! A logical frequency table split into one or more subtables of moderate size.
291+
struct FrequencySupertable
292+
: Supertable<pb::FrequencySubtable, Frequencies::iterator> {
293+
//! Constructor.
294+
//!
295+
//! Construct and populate a `FrequencySupertable` from a vector of symbol
296+
//! frequencies.
297+
//!
298+
//!\param frequencies Symbol frequencies to store in the subtables.
299+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
300+
//! `pb::HuffmanHeader`). This field will be written to.
301+
FrequencySupertable(const std::vector<std::size_t> &frequencies,
302+
SubtableSizes &nbytes_subtables)
303+
: Supertable(std::count_if(frequencies.begin(), frequencies.end(),
304+
[](const std::size_t frequency) -> bool {
305+
return frequency;
306+
}),
307+
nbytes_subtables) {
308+
// `i` is the index of the subtable we're inserting into. (Technically
309+
// we're inserting into the subtable's frequency map field rather than
310+
// the subtable itself.) `j` is the number of entries we've inserted
311+
// into subtable `i`. `k` is the index in the vector of frequencies
312+
// passed to the constructor.
313+
std::size_t k = 0;
314+
for (std::size_t i = 0; i < nsubtables; ++i) {
315+
Frequencies &frequencies_ = *subtables.at(i).mutable_frequencies();
316+
Segment &segment = segments.at(i);
317+
// How big `frequencies_` should be when we're done.
318+
const std::size_t nfrequencies_ = segment.second;
319+
for (std::size_t j = 0; j < nfrequencies_; ++k) {
320+
const std::size_t frequency = frequencies.at(k);
321+
if (frequency) {
322+
frequencies_.insert({k, frequency});
323+
++j;
324+
}
325+
}
326+
segment.first = frequencies_.begin();
327+
}
328+
329+
calculate_nbytes_subtables();
330+
}
331+
332+
//! Constructor.
333+
//!
334+
//! Construct a `FrequencySubtable` from a collection of parsed messages.
335+
//!
336+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
337+
//! `pb::HuffmanHeader`).
338+
//!\param window Window into buffer containing messages to be parsed.
339+
FrequencySupertable(SubtableSizes &nbytes_subtables, BufferWindow &window)
340+
: Supertable(nbytes_subtables, window) {
341+
for (std::size_t i = 0; i < nsubtables; ++i) {
342+
Segment &segment = segments.at(i);
343+
Frequencies &frequencies = *subtables.at(i).mutable_frequencies();
344+
345+
segment.first = frequencies.begin();
346+
segment.second = frequencies.size();
347+
}
348+
}
349+
};
350+
351+
//! A logical 'missed' table split into one or more subtables of moderate size.
352+
struct MissedSupertable : Supertable<pb::MissedSubtable, Missed::iterator> {
353+
//! Constructor.
354+
//!
355+
//! Construct an 'empty' `MissedSupertable`. It is expected that the caller
356+
//! will subsequently write to the subtables using `Chain`.
357+
//!
358+
//!\param nmissed Number of missed symbols.
359+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
360+
//! `pb::HuffmanHeader`). This field will be written to.
361+
MissedSupertable(const std::size_t nmissed, SubtableSizes &nbytes_subtables)
362+
: Supertable(nmissed, nbytes_subtables) {
363+
for (std::size_t i = 0; i < nsubtables; ++i) {
364+
Missed &missed = *subtables.at(i).mutable_missed();
365+
Segment &segment = segments.at(i);
366+
// How big `missed` should be when we're done.
367+
const std::size_t nmissed = segment.second;
368+
369+
missed.Resize(nmissed, 0);
370+
segment.first = missed.begin();
371+
}
372+
}
373+
374+
//! Constructor.
375+
//!
376+
//! Construct a `MissedSubtable` from a collection of parsed messages.
377+
//!
378+
//!\param nbytes_subtables Sizes in bytes of the subtables (field in
379+
//! `pb::HuffmanHeader`).
380+
//!\param window Window into buffer containing messages to be parsed.
381+
MissedSupertable(SubtableSizes &nbytes_subtables, BufferWindow &window)
382+
: Supertable(nbytes_subtables, window) {
383+
for (std::size_t i = 0; i < nsubtables; ++i) {
384+
Segment &segment = segments.at(i);
385+
Missed &missed = *subtables.at(i).mutable_missed();
386+
387+
segment.first = missed.begin();
388+
segment.second = missed.size();
389+
}
390+
}
391+
};
392+
393+
} // namespace
394+
179395
template <typename Symbol>
180396
MemoryBuffer<unsigned char> huffman_encode(Symbol const *const begin,
181397
const std::size_t n) {
@@ -188,7 +404,7 @@ MemoryBuffer<unsigned char> huffman_encode(Symbol const *const begin,
188404
const std::size_t nbits =
189405
std::inner_product(code.frequencies.begin(), code.frequencies.end(),
190406
lengths.begin(), static_cast<std::size_t>(0));
191-
const std::size_t nbytes = (nbits + CHAR_BIT - 1) / CHAR_BIT;
407+
const std::size_t nbytes_hit = (nbits + CHAR_BIT - 1) / CHAR_BIT;
192408

193409
pb::HuffmanHeader header;
194410
header.set_index_mapping(pb::HuffmanHeader::INCLUSIVE_RANGE);
@@ -200,23 +416,18 @@ MemoryBuffer<unsigned char> huffman_encode(Symbol const *const begin,
200416
header.add_endpoints(code.endpoints.second);
201417
header.set_nbits(nbits);
202418

203-
Frequencies &frequencies = *header.mutable_frequencies();
204-
{
205-
std::size_t i = 0;
206-
for (const std::size_t frequency : code.frequencies) {
207-
if (frequency) {
208-
frequencies.insert({i, frequency});
209-
}
210-
++i;
211-
}
212-
}
419+
FrequencySupertable frequency_supertable(
420+
code.frequencies, *header.mutable_nbytes_frequency_subtables());
421+
MissedSupertable missed_supertable(code.nmissed(),
422+
*header.mutable_nbytes_missed_subtables());
213423

214-
Missed &missed_ = *header.mutable_missed();
215-
missed_.Resize(code.nmissed(), 0);
216-
Missed::iterator missed = missed_.begin();
424+
Chain<Missed::iterator> chained_missed_supertable(missed_supertable.segments);
425+
Chain<Missed::iterator>::iterator missed = chained_missed_supertable.begin();
426+
// Now we're ready to populate the 'missed' subtables in the course of
427+
// populating the 'hit' buffer.
217428

218429
// Zero-initialize the bytes.
219-
unsigned char *const hit_ = new unsigned char[nbytes]();
430+
unsigned char *const hit_ = new unsigned char[nbytes_hit]();
220431
unsigned char *hit = hit_;
221432

222433
unsigned char offset = 0;
@@ -249,8 +460,18 @@ MemoryBuffer<unsigned char> huffman_encode(Symbol const *const begin,
249460
}
250461
}
251462

463+
// We're done writing to the 'missed' subtables, so we can now calculate their
464+
// serialized sizes. We need to do this before calling
465+
// `missed_supertable.ByteSize`.
466+
missed_supertable.calculate_nbytes_subtables();
467+
252468
const std::uint_least64_t nheader = header.ByteSize();
253-
MemoryBuffer<unsigned char> out(HEADER_SIZE_SIZE + nheader + nbytes);
469+
const std::size_t nbytes_frequency_supertable =
470+
frequency_supertable.ByteSize();
471+
const std::size_t nbytes_missed_supertable = missed_supertable.ByteSize();
472+
MemoryBuffer<unsigned char> out(HEADER_SIZE_SIZE + nheader +
473+
nbytes_frequency_supertable +
474+
nbytes_missed_supertable + nbytes_hit);
254475
{
255476
unsigned char *p = out.data.get();
256477
const std::array<unsigned char, HEADER_SIZE_SIZE> nheader_ =
@@ -261,8 +482,14 @@ MemoryBuffer<unsigned char> huffman_encode(Symbol const *const begin,
261482
header.SerializeToArray(p, nheader);
262483
p += nheader;
263484

264-
std::copy(hit_, hit_ + nbytes, p);
265-
p += nbytes;
485+
frequency_supertable.SerializeToArray(p, nbytes_frequency_supertable);
486+
p += nbytes_frequency_supertable;
487+
488+
missed_supertable.SerializeToArray(p, nbytes_missed_supertable);
489+
p += nbytes_missed_supertable;
490+
491+
std::copy(hit_, hit_ + nbytes_hit, p);
492+
p += nbytes_hit;
266493
}
267494

268495
delete[] hit_;
@@ -283,19 +510,24 @@ MemoryBuffer<Symbol> huffman_decode(const MemoryBuffer<unsigned char> &buffer) {
283510
if (endpoints_.size() != 2) {
284511
throw std::runtime_error("received an unexpected number of endpoints");
285512
}
286-
const std::pair<std::size_t, std::size_t> endpoints(endpoints_.Get(0),
287-
endpoints_.Get(1));
513+
const std::pair<Symbol, Symbol> endpoints(endpoints_.Get(0),
514+
endpoints_.Get(1));
288515

289516
if (header.codeword_mapping() != pb::HuffmanHeader::INDEX_FREQUENCY_PAIRS) {
290517
throw std::runtime_error("unrecognized Huffman codeword mapping");
291518
}
292-
const Frequencies &frequencies_ = header.frequencies();
519+
FrequencySupertable frequency_supertable(
520+
*header.mutable_nbytes_frequency_subtables(), window);
521+
Chain<Frequencies::iterator> chained_frequency_supertable(
522+
frequency_supertable.segments);
293523

294524
if (header.missed_encoding() != pb::HuffmanHeader::LITERAL) {
295525
throw std::runtime_error("unrecognized Huffman missed buffer encoding");
296526
}
297-
const Missed &missed_ = header.missed();
298-
Missed::const_iterator missed = missed_.cbegin();
527+
MissedSupertable missed_supertable(*header.mutable_nbytes_missed_subtables(),
528+
window);
529+
Chain<Missed::iterator> chained_missed_supertable(missed_supertable.segments);
530+
Chain<Missed::iterator>::iterator missed = chained_missed_supertable.begin();
299531

300532
if (header.hit_encoding() != pb::HuffmanHeader::RUN_TOGETHER) {
301533
throw std::runtime_error("unrecognized Huffman hit buffer encoding");
@@ -308,8 +540,9 @@ MemoryBuffer<Symbol> huffman_decode(const MemoryBuffer<unsigned char> &buffer) {
308540
"number of bytes in hit buffer");
309541
}
310542

311-
const HuffmanCode<Symbol> code(endpoints, frequencies_.begin(),
312-
frequencies_.end());
543+
const HuffmanCode<Symbol> code(endpoints,
544+
chained_frequency_supertable.begin(),
545+
chained_frequency_supertable.end());
313546
// TODO: Maybe add a member function for this.
314547
const std::size_t nout =
315548
std::accumulate(code.frequencies.begin(), code.frequencies.end(),
@@ -332,7 +565,7 @@ MemoryBuffer<Symbol> huffman_decode(const MemoryBuffer<unsigned char> &buffer) {
332565
*q++ = decoded.first ? decoded.second : *missed++;
333566
}
334567
assert(nbits_read == nbits);
335-
assert(missed == missed_.cend());
568+
assert(missed == chained_missed_supertable.end());
336569

337570
return out;
338571
}

src/mgard.proto

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,24 @@ message HuffmanHeader {
189189

190190
// Minimum and maximum symbols eligible for codewords.
191191
repeated sint64 endpoints = 5;
192+
// Sizes in bytes of serialized `FrequencySubtable`s to followw.
193+
repeated uint64 nbytes_frequency_subtables = 6;
194+
// Sizes in bytes of serialized `MissedSubtable`s to follow.
195+
repeated uint64 nbytes_missed_subtables = 7;
196+
// Size in bits of the hit buffer to follow.
197+
uint64 nbits = 8;
198+
}
199+
200+
// One or more of these will follow a `HuffmanHeader`.
201+
message FrequencySubtable {
192202
// Index–frequency pairs for frequency table.
193203
map<uint64, uint64> frequencies = 6;
204+
}
205+
206+
// One or more of these will follow the `FrequencySubtable`s after a `HuffmanHeader`.
207+
message MissedSubtable {
194208
// Encountered symbols that were not assigned codewords.
195209
repeated sint64 missed = 7;
196-
// Size of the hit buffer in bits.
197-
uint64 nbits = 8;
198210
}
199211

200212
message Device {

0 commit comments

Comments
 (0)