Skip to content

Commit 57b138a

Browse files
Merge pull request #64 from JanekEbb/master
fix assess in DynamicTimeSeriesBucket when max_total_size is used
2 parents 0c32ff7 + 84b1b8b commit 57b138a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

lazy_dataset/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3403,7 +3403,13 @@ def is_completed(self):
34033403

34043404
def assess(self, example):
34053405
seq_len = self.len_key(example)
3406-
return self.lower_bound <= seq_len <= self.upper_bound
3406+
return (
3407+
(self.lower_bound <= seq_len <= self.upper_bound)
3408+
and (
3409+
(self.max_total_size is None)
3410+
or ((len(self.data) + 1) * max(self.max_len, seq_len) <= self.max_total_size)
3411+
)
3412+
)
34073413

34083414
def _append(self, example):
34093415
super()._append(example)

tests/test_bucket.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,16 @@ def test_bucket():
1919
assert dynamic_batched_buckets == [
2020
[10, 5], [7, 8], [1, 2], [4, 3], [6, 9], [20], [1]
2121
]
22+
23+
24+
def test_max_total_size():
25+
examples = [6, 7, 9, 5, 6, 3, 7, 4]
26+
examples = {str(j): i for j, i in enumerate(examples)}
27+
ds = lazy_dataset.new(examples)
28+
29+
dynamic_batched_buckets = list(ds.batch_dynamic_time_series_bucket(
30+
batch_size=3, len_key=lambda x: x, max_padding_rate=0.9, max_total_size=21,
31+
))
32+
assert dynamic_batched_buckets == [
33+
[6, 7, 5], [9, 6], [3, 7, 4]
34+
]

0 commit comments

Comments
 (0)