
Commit f70df59

Add arrow adapter to stream record batches zero-copy into csp
Signed-off-by: Arham Chopra <[email protected]>
1 parent d30d188 commit f70df59

File tree

1 file changed: +253 −0 lines changed


csp/adapters/arrow.py

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
import itertools
import queue
import threading
from typing import Iterable, List, Optional

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

import csp
from csp.impl.types.tstype import ts
from csp.impl.wiring import py_pull_adapter_def, py_push_adapter_def

__all__ = [
    "ArrowRealtimeAdapter",
    "ArrowHistoricalAdapter",
    "accumulate_record_batches",
]


class ArrowRealtimeAdapterImpl(csp.impl.pushadapter.PushInputAdapter):
    """Stream record batches in realtime into csp"""

    def __init__(self, timeout: int, source: queue.Queue[pa.RecordBatch]):
        """
        Args:
            timeout: max time in seconds to block when waiting for results from the queue
            source: queue of streaming record batches, needs to be provided by the user
        """
        self.timeout = timeout
        self.queue = source
        self._thread = None
        self._running = False
        self._exc = None
        super().__init__()

    def start(self, start_time, end_time):
        self._thread = threading.Thread(target=self._run)
        self._running = True
        self._thread.start()

    def stop(self):
        if self._running:
            self._running = False
            self._thread.join()
        if self._exc:
            raise self._exc

    def _run(self):
        while self._running:
            try:
                new_batches = self.queue.get(block=True, timeout=self.timeout)
                self.push_tick(new_batches)
            except queue.Empty:
                # No new data, loop back
                pass
            except Exception as e:
                self._exc = e
                break


ArrowRealtimeAdapter = py_push_adapter_def(
    "ArrowRealtimeAdapter",
    ArrowRealtimeAdapterImpl,
    ts[List[pa.RecordBatch]],
    timeout=int,
    source=queue.Queue[pa.RecordBatch],
)
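
# Usage sketch (illustrative only, not exported by this module): wiring the realtime
# adapter into a graph. It assumes the caller owns a queue.Queue that some producer
# thread fills with lists of record batches; the 1 second timeout and 60 second run
# window below are arbitrary choices for demonstration.
#
#   from datetime import datetime, timedelta
#
#   source_queue = queue.Queue()
#
#   @csp.graph
#   def realtime_graph():
#       batches = ArrowRealtimeAdapter(timeout=1, source=source_queue)
#       csp.print("batches", batches)
#
#   csp.run(realtime_graph, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True)
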


class ArrowHistoricalAdapterImpl(csp.impl.pulladapter.PullInputAdapter):
    """Stream record batches from some source into csp"""

    def __init__(
        self,
        ts_col_name: str,
        stream: Optional[Iterable[pa.RecordBatch]],
        tables: Optional[Iterable[pa.Table]],
        filenames: Optional[Iterable[str]],
    ):
        """
        Args:
            ts_col_name: name of the column that contains the timestamp field
            stream: an optional iterable of record batches
            tables: an optional iterable of arrow tables to read from
            filenames: an optional iterable of parquet files to read from

        NOTE: The user is responsible for ensuring that the data is sorted in ascending order on the 'ts_col_name' field
        NOTE: batches from stream, tables and filenames are iterated in that order
        """
        assert stream or filenames or tables, "At least one of stream, filenames, or tables must not be None"
        self.stream = stream
        self.tables = tables
        self.filenames = filenames
        self.ts_col_name = ts_col_name
        super().__init__()

    def start(self, start_time, end_time):
        self.start_time = start_time
        self.end_time = end_time

        # Info about the last chunk of data
        self.last_chunk = None
        self.last_ts = None
        # Number of chunks in this batch
        self.batch_chunks_count = 0
        # Iterator for iterating over the chunks in a batch
        self.chunk_index_iter = None
        # Number of chunks processed so far
        self.processed_chunks_count = 0
        # Current batch being processed
        self.batch = None
        # All batches processed
        self.finished = False
        # Start time filtering done
        self.filtered_start_time = False
        # The starting batch with start_time filtered
        self.starting_batch = None

        batch_iters = []
        if self.stream:
            batch_iters += [self.stream]

        if self.tables:
            batch_iters += [table.to_batches() for table in self.tables]

        if self.filenames:
            batch_iters += [pq.ParquetFile(filename).iter_batches() for filename in self.filenames]

        self.source = itertools.chain(*batch_iters)

        super().start(start_time, end_time)

    def next(self):
        if self.finished:
            return None

        # Filter out all batches which have ts < start time
        while not self.filtered_start_time and not self.finished:
            try:
                batch = next(self.source)
                if batch.num_rows != 0:
                    # NOTE: filter might be a good option to avoid this indirect way of computing the slice,
                    # however I am not sure if filter will be zero copy
                    valid_indices = pc.indices_nonzero(pc.greater_equal(batch[self.ts_col_name], self.start_time))
                    if len(valid_indices) != 0:
                        # Slice to only get the records with ts >= start_time
                        self.starting_batch = batch.slice(offset=valid_indices[0].as_py())
                        self.filtered_start_time = True
            except StopIteration:
                self.finished = True

        while not self.finished:
            # Process all the chunks in current batch
            if self.chunk_index_iter:
                try:
                    start_idx, next_start_idx = next(self.chunk_index_iter)
                    new_batches = [self.batch.slice(offset=start_idx, length=next_start_idx - start_idx)]
                    new_ts = self.batch[self.ts_col_name][start_idx].as_py()
                    self.processed_chunks_count += 1
                    if self.last_chunk:
                        if self.last_ts == new_ts:
                            new_batches = self.last_chunk + new_batches
                            self.last_chunk = None
                            self.last_ts = None
                        else:
                            raise Exception("last_chunk and new_batches have different timestamps")

                    if self.processed_chunks_count == self.batch_chunks_count:
                        self.last_chunk = new_batches
                        self.last_ts = new_ts
                        self.processed_chunks_count = 0
                    else:
                        if new_ts > self.end_time:
                            self.finished = True
                            continue
                        return (new_ts, new_batches)
                except StopIteration:
                    raise Exception("chunk_index_iter reached end, how?")

            # Try to get a new batch of data
            try:
                if self.starting_batch:
                    # Use the sliced batch from start_time filtering
                    self.batch = self.starting_batch
                    self.starting_batch = None
                else:
                    # Get the next batch of data
                    self.batch = next(self.source)
                if self.batch.num_rows == 0:
                    continue

                all_timestamps = self.batch[self.ts_col_name]
                unique_timestamps = all_timestamps.unique()
                # First occurrence of each unique timestamp marks a chunk boundary; append
                # num_rows so the final chunk runs to the end of the batch
                indexes = pc.index_in(unique_timestamps, all_timestamps).to_pylist() + [self.batch.num_rows]
                self.chunk_index_iter = zip(indexes, indexes[1:])
                self.batch_chunks_count = len(unique_timestamps)
                starting_ts = unique_timestamps[0].as_py()
                if starting_ts != self.last_ts and self.last_chunk:
                    new_batches = self.last_chunk
                    new_ts = self.last_ts
                    self.last_chunk = None
                    self.last_ts = None
                    if new_ts > self.end_time:
                        self.finished = True
                        continue
                    return (new_ts, new_batches)
            except StopIteration:
                self.finished = True
                if self.last_chunk:
                    if self.last_ts > self.end_time:
                        continue
                    return (self.last_ts, self.last_chunk)
        return None


ArrowHistoricalAdapter = py_pull_adapter_def(
    "ArrowHistoricalAdapter",
    ArrowHistoricalAdapterImpl,
    ts[List[pa.RecordBatch]],
    ts_col_name=str,
    stream=Optional[Iterable[pa.RecordBatch]],
    tables=Optional[Iterable[pa.Table]],
    filenames=Optional[Iterable[str]],
)
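
# Usage sketch (illustrative only): replaying sorted parquet files through the pull
# adapter. The file name and timestamp column below ("trades.parquet", "ts") are
# hypothetical; the data must already be sorted ascending on the timestamp column.
#
#   from datetime import datetime
#
#   @csp.graph
#   def replay_graph():
#       batches = ArrowHistoricalAdapter(
#           ts_col_name="ts", stream=None, tables=None, filenames=["trades.parquet"]
#       )
#       csp.print("batches", batches)
#
#   csp.run(replay_graph, starttime=datetime(2023, 1, 1), endtime=datetime(2023, 1, 2))
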


@csp.node
def accumulate_record_batches(filename: str, merge_record_batches: bool, batches: csp.ts[List[pa.RecordBatch]]):
    """
    Dump all the record batches to a parquet file

    Args:
        filename: name of the file to write the data to
        merge_record_batches: A flag to combine all the record batches of a single tick into a single record batch (can save some space at the cost of memory)
        batches: The timeseries of lists of record batches
    """
    with csp.state():
        s_writer = None
        s_filename = filename
        s_merge_batches = merge_record_batches

    with csp.stop():
        if s_writer is not None:
            s_writer.close()

    if csp.ticked(batches):
        if s_merge_batches:
            batches = [pa.concat_batches(batches)]

        for batch in batches:
            if s_writer is None:
                s_writer = pq.ParquetWriter(s_filename, batch.schema)
            s_writer.write_batch(batch)
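
# Usage sketch (illustrative only): replaying one parquet file with the historical
# adapter and re-accumulating the ticked batches into another file. The file names
# ("input.parquet", "output.parquet"), the "ts" column, and the date range are
# hypothetical assumptions for demonstration.
#
#   from datetime import datetime
#
#   @csp.graph
#   def copy_graph():
#       batches = ArrowHistoricalAdapter(
#           ts_col_name="ts", stream=None, tables=None, filenames=["input.parquet"]
#       )
#       accumulate_record_batches("output.parquet", True, batches)
#
#   csp.run(copy_graph, starttime=datetime(2023, 1, 1), endtime=datetime(2023, 12, 31))
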

0 commit comments
