 from collections import defaultdict
 from collections.abc import Callable
 from dataclasses import dataclass
+from datetime import datetime
 from typing import Any, Generic, TypeVar

 import sentry_sdk
@@ -30,25 +31,25 @@ class UnassignedPartitionError(Exception):

 @dataclass
 class WorkItem(Generic[T]):
-    """Work item that includes the original message for offset tracking."""
+    """Work item for processing with offset tracking information."""

     partition: Partition
     offset: int
+    timestamp: datetime
     result: T
-    message: Message[KafkaPayload | FilteredPayload]


 class OffsetTracker:
     """
     Tracks outstanding offsets and determines which offsets are safe to commit.

-    - Tracks offsets per partition
+    - Tracks offsets per partition with their timestamps
     - Only commits offsets when all prior offsets are processed
     - Thread-safe for concurrent access with per-partition locks
     """

     def __init__(self) -> None:
-        self.all_offsets: dict[Partition, set[int]] = defaultdict(set)
+        self.all_offsets: dict[Partition, dict[int, datetime]] = defaultdict(dict)
         self.outstanding: dict[Partition, set[int]] = defaultdict(set)
         self.last_committed: dict[Partition, int] = {}
         self.partition_locks: dict[Partition, threading.Lock] = {}
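
For orientation, a minimal sketch of the intended call sequence (hypothetical topic and offsets; it assumes the partition was already registered in partition_locks, which rebalance handling does elsewhere):

    import threading
    from datetime import datetime, timezone

    from arroyo.types import Partition, Topic

    tracker = OffsetTracker()
    partition = Partition(Topic("events"), 0)               # hypothetical partition
    tracker.partition_locks[partition] = threading.Lock()   # normally set on assignment

    now = datetime.now(timezone.utc)
    tracker.add_offset(partition, 5, now)   # started processing offset 5
    tracker.add_offset(partition, 6, now)   # started processing offset 6
    tracker.complete_offset(partition, 5)   # 5 finished; 6 still outstanding

    # Only offset 5 is committable; its timestamp rides along for the latency metric.
    assert tracker.get_committable_offsets() == {partition: (5, now)}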
@@ -57,15 +58,15 @@ def _get_partition_lock(self, partition: Partition) -> threading.Lock:
         """Get the lock for a partition."""
         return self.partition_locks[partition]

-    def add_offset(self, partition: Partition, offset: int) -> None:
+    def add_offset(self, partition: Partition, offset: int, timestamp: datetime) -> None:
         """Record that we've started processing an offset."""
         if partition not in self.partition_locks:
             raise UnassignedPartitionError(
                 f"Partition {partition} is not assigned to this consumer"
             )

         with self._get_partition_lock(partition):
-            self.all_offsets[partition].add(offset)
+            self.all_offsets[partition][offset] = timestamp
             self.outstanding[partition].add(offset)

     def complete_offset(self, partition: Partition, offset: int) -> None:
@@ -76,11 +77,12 @@ def complete_offset(self, partition: Partition, offset: int) -> None:
         with self._get_partition_lock(partition):
             self.outstanding[partition].discard(offset)

-    def get_committable_offsets(self) -> dict[Partition, int]:
+    def get_committable_offsets(self) -> dict[Partition, tuple[int, datetime]]:
         """
         Get the highest offset per partition that can be safely committed.

-        For each partition, finds the highest contiguous offset that has been processed.
+        For each partition, finds the highest contiguous offset that has been processed,
+        and returns the timestamp of the oldest offset in the contiguous set.
         """
         committable = {}
         for partition in list(self.all_offsets.keys()):
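
Note the signature change: callers now get dict[Partition, tuple[int, datetime]] instead of dict[Partition, int]. That ripple is what drives the _commit_loop changes further down, which unpack the tuple for the latency metric and re-pack plain offsets for the commit callback.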
@@ -92,20 +94,26 @@ def get_committable_offsets(self) -> dict[Partition, int]:
             outstanding = self.outstanding[partition]
             last_committed = self.last_committed.get(partition, -1)

-            min_offset = min(all_offsets)
-            max_offset = max(all_offsets)
+            offset_keys = set(all_offsets.keys())
+            min_offset = min(offset_keys)
+            max_offset = max(offset_keys)

             start = max(last_committed + 1, min_offset)

             highest_committable = last_committed
+            oldest_timestamp = None
+
             for offset in range(start, max_offset + 1):
                 if offset in all_offsets and offset not in outstanding:
                     highest_committable = offset
+                    timestamp = all_offsets[offset]
+                    if oldest_timestamp is None or timestamp < oldest_timestamp:
+                        oldest_timestamp = timestamp
                 else:
                     break

-            if highest_committable > last_committed:
-                committable[partition] = highest_committable
+            if highest_committable > last_committed and oldest_timestamp:
+                committable[partition] = (highest_committable, oldest_timestamp)

         return committable

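A worked example of the contiguity rule, with hypothetical offsets and timestamps:

    # all_offsets    = {5: t5, 6: t6, 7: t7, 8: t8}
    # outstanding    = {7}              # offset 7 is still in a worker queue
    # last_committed = 4
    #
    # The scan starts at 5 and stops at the first gap (7), so only 5 and 6
    # qualify, and the partition maps to (6, min(t5, t6)): the commit offset
    # plus the oldest timestamp in the contiguous run. Offset 8 stays tracked
    # until 7 completes.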
@@ -114,7 +122,9 @@ def mark_committed(self, partition: Partition, offset: int) -> None:
         with self._get_partition_lock(partition):
             self.last_committed[partition] = offset
             # Remove all offsets <= committed offset
-            self.all_offsets[partition] = {o for o in self.all_offsets[partition] if o > offset}
+            self.all_offsets[partition] = {
+                k: v for k, v in self.all_offsets[partition].items() if k > offset
+            }

     def clear(self) -> None:
         """Clear all offset tracking state."""
@@ -193,13 +203,15 @@ def __init__(
         self,
         result_processor: Callable[[str, T], None],
         identifier: str,
+        consumer_group: str,
         num_queues: int = 20,
         commit_interval: float = 1.0,
     ) -> None:
         self.result_processor = result_processor
         self.identifier = identifier
         self.num_queues = num_queues
         self.commit_interval = commit_interval
+        self.consumer_group = consumer_group
         self.offset_tracker = OffsetTracker()
         self.queues: list[queue.Queue[WorkItem[T]]] = []
         self.workers: list[OrderedQueueWorker[T]] = []
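
consumer_group is threaded through purely so the commit loop below can tag the new latency metric; it has no effect on queueing or commit behavior.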
@@ -239,9 +251,21 @@ def _commit_loop(self) -> None:
                         len(committable),
                         tags={"identifier": self.identifier},
                     )
-
-                    self.commit_function(committable)
-                    for partition, offset in committable.items():
+                    for partition, (offset, oldest_timestamp) in committable.items():
+                        metrics.timing(
+                            "arroyo.consumer.latency",
+                            time.time() - oldest_timestamp.timestamp(),
+                            tags={
+                                "partition": partition.index,
+                                "kafka_topic": partition.topic.name,
+                                "consumer_group": self.consumer_group,
+                            },
+                        )
+
+                    self.commit_function(
+                        {partition: offset for partition, (offset, _) in committable.items()}
+                    )
+                    for partition, (offset, _) in committable.items():
                         self.offset_tracker.mark_committed(partition, offset)
             except Exception:
                 logger.exception("Error in commit loop")
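
The new metric is commit-time lag: wall-clock time at commit minus the broker timestamp of the oldest message in the contiguous run being committed. A self-contained sketch of the arithmetic, with hypothetical values:

    from datetime import datetime, timezone

    oldest_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    now = 1704110460.0  # time.time() at 12:01:00 UTC on the same day

    latency_seconds = now - oldest_timestamp.timestamp()
    assert latency_seconds == 60.0  # one minute from broker write to commit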
@@ -257,7 +281,9 @@ def submit(self, group_key: str, work_item: WorkItem[T]) -> None:
         Submit a work item to the appropriate queue.
         """
         try:
-            self.offset_tracker.add_offset(work_item.partition, work_item.offset)
+            self.offset_tracker.add_offset(
+                work_item.partition, work_item.offset, work_item.timestamp
+            )
         except UnassignedPartitionError:
             logger.exception(
                 "Received message for unassigned partition, skipping",
@@ -400,10 +426,11 @@ def submit(self, message: Message[KafkaPayload | FilteredPayload]) -> None:
         assert isinstance(message.value, BrokerValue)
         partition = message.value.partition
         offset = message.value.offset
+        timestamp = message.value.timestamp

         if result is None:
             try:
-                self.queue_pool.offset_tracker.add_offset(partition, offset)
+                self.queue_pool.offset_tracker.add_offset(partition, offset, timestamp)
                 self.queue_pool.offset_tracker.complete_offset(partition, offset)
             except UnassignedPartitionError:
                 pass
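
Messages that produce no result still go through add_offset immediately followed by complete_offset. The offset has to enter the tracker: if it never did, it would read as a permanent gap and get_committable_offsets could never advance past it.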
@@ -414,21 +441,24 @@ def submit(self, message: Message[KafkaPayload | FilteredPayload]) -> None:
             work_item = WorkItem(
                 partition=partition,
                 offset=offset,
+                timestamp=timestamp,
                 result=result,
-                message=message,
             )

             self.queue_pool.submit(group_key, work_item)

         except Exception:
             logger.exception("Error submitting message to queue")
             if isinstance(message.value, BrokerValue):
-                self.queue_pool.offset_tracker.add_offset(
-                    message.value.partition, message.value.offset
-                )
-                self.queue_pool.offset_tracker.complete_offset(
-                    message.value.partition, message.value.offset
-                )
+                try:
+                    self.queue_pool.offset_tracker.add_offset(
+                        message.value.partition, message.value.offset, message.value.timestamp
+                    )
+                    self.queue_pool.offset_tracker.complete_offset(
+                        message.value.partition, message.value.offset
+                    )
+                except UnassignedPartitionError:
+                    pass

     def poll(self) -> None:
         stats = self.queue_pool.get_stats()