Skip to content

Commit 2beb3e3

Browse files
aijamysutedja
authored andcommitted
Keep chunk requests and responses synchronized
1 parent 6b51317 commit 2beb3e3

File tree

3 files changed

+61
-74
lines changed

3 files changed

+61
-74
lines changed

splunklib/searchcommands/generating_command.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,16 @@ def generate(self):
194194
"""
195195
raise NotImplementedError('GeneratingCommand.generate(self)')
196196

197+
198+
def __generate_chunk(self, unused_input):
199+
count = 0
200+
for row in generate():
201+
yield row
202+
count += 1
203+
if count == self._record_writer._maxresultrows:
204+
# count = 0
205+
return
206+
197207
def _execute(self, ifile, process):
198208
""" Execution loop
199209
@@ -204,18 +214,10 @@ def _execute(self, ifile, process):
204214
205215
"""
206216
if self._protocol_version == 2:
207-
result = self._read_chunk(ifile)
208-
209-
if not result:
210-
return
211-
212-
metadata, body = result
213-
action = getattr(metadata, 'action', None)
214-
215-
if action != 'execute':
216-
raise RuntimeError('Expected execute action, not {}'.format(action))
217-
218-
self._record_writer.write_records(self.generate())
217+
self._execute_v2(ifile, self.__generate_chunk)
218+
else:
219+
assert self._protocol_version == 1
220+
self._record_writer.write_records(self.generate())
219221
self.finish()
220222

221223
# endregion

splunklib/searchcommands/internals.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ def _write_record(self, record):
654654
self._record_count += 1
655655

656656
if self._record_count >= self._maxresultrows:
657+
657658
self.flush(partial=True)
658659

659660
try:
@@ -748,42 +749,39 @@ class RecordWriterV2(RecordWriter):
748749
def flush(self, finished=None, partial=None):
749750

750751
RecordWriter.flush(self, finished, partial) # validates arguments and the state of this instance
751-
inspector = self._inspector
752752

753-
if partial:
753+
if partial or not finished:
754754
# Don't flush partial chunks, since the SCP v2 protocol does not
755755
# provide a way to send partial chunks yet.
756756
return
757757

758-
if self._flushed is False:
759-
760-
self._total_record_count += self._record_count
761-
self._chunk_count += 1
762-
763-
# TODO: DVPL-6448: splunklib.searchcommands | Add support for partial: true when it is implemented in
764-
# ChunkedExternProcessor (See SPL-103525)
765-
#
766-
# We will need to replace the following block of code with this block:
767-
#
768-
# metadata = [
769-
# ('inspector', self._inspector if len(self._inspector) else None),
770-
# ('finished', finished),
771-
# ('partial', partial)]
772-
773-
if len(inspector) == 0:
774-
inspector = None
758+
#if finished is True:
759+
self.write_chunk(finished=True)
775760

776-
if partial is True:
777-
finished = False
778-
779-
metadata = [item for item in (('inspector', inspector), ('finished', finished))]
780-
self._write_chunk(metadata, self._buffer.getvalue())
781-
self._clear()
782-
783-
elif finished is True:
784-
self._write_chunk((('finished', True),), '')
785-
786-
self._finished = finished is True
761+
def write_chunk(self, finished=None):
762+
inspector = self._inspector
763+
self._total_record_count += self._record_count
764+
self._chunk_count += 1
765+
766+
# TODO: DVPL-6448: splunklib.searchcommands | Add support for partial: true when it is implemented in
767+
# ChunkedExternProcessor (See SPL-103525)
768+
#
769+
# We will need to replace the following block of code with this block:
770+
#
771+
# metadata = [
772+
# ('inspector', self._inspector if len(self._inspector) else None),
773+
# ('finished', finished),
774+
# ('partial', partial)]
775+
776+
if len(inspector) == 0:
777+
inspector = None
778+
779+
#if partial is True:
780+
# finished = False
781+
782+
metadata = [item for item in (('inspector', inspector), ('finished', finished))]
783+
self._write_chunk(metadata, self._buffer.getvalue())
784+
self._clear()
787785

788786
def write_metadata(self, configuration):
789787
self._ensure_validity()

splunklib/searchcommands/search_command.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _process_protocol_v2(self, argv, ifile, ofile):
776776
# noinspection PyBroadException
777777
try:
778778
debug('Executing under protocol_version=2')
779-
self._records = self._records_protocol_v2
779+
#self._records = self._records_protocol_v2
780780
self._metadata.action = 'execute'
781781
self._execute(ifile, None)
782782
except SystemExit:
@@ -833,6 +833,8 @@ def _decode_list(mv):
833833

834834
_encoded_value = re.compile(r'\$(?P<item>(?:\$\$|[^$])*)\$(?:;|$)') # matches a single value in an encoded list
835835

836+
# Note: Subclasses must override this method so that it can be called
837+
# called as self._execute(ifile, None)
836838
def _execute(self, ifile, process):
837839
""" Default processing loop
838840
@@ -846,8 +848,12 @@ def _execute(self, ifile, process):
846848
:rtype: NoneType
847849
848850
"""
849-
self._record_writer.write_records(process(self._records(ifile)))
850-
self.finish()
851+
if self.protocol_version == 1:
852+
self._record_writer.write_records(process(self._records(ifile)))
853+
self.finish()
854+
else:
855+
assert self._protocol_version == 2
856+
self._execute_v2(ifile, process)
851857

852858
@staticmethod
853859
def _read_chunk(ifile):
@@ -896,7 +902,9 @@ def _read_chunk(ifile):
896902
_header = re.compile(r'chunked\s+1.0\s*,\s*(\d+)\s*,\s*(\d+)\s*\n')
897903

898904
def _records_protocol_v1(self, ifile):
905+
return self._read_csv_records(ifile)
899906

907+
def _read_csv_records(self, ifile):
900908
reader = csv.reader(ifile, dialect=CsvDialect)
901909

902910
try:
@@ -921,7 +929,7 @@ def _records_protocol_v1(self, ifile):
921929
record[fieldname] = value
922930
yield record
923931

924-
def _records_protocol_v2(self, ifile):
932+
def _execute_v2(self, ifile, process):
925933

926934
while True:
927935
result = self._read_chunk(ifile)
@@ -931,41 +939,20 @@ def _records_protocol_v2(self, ifile):
931939

932940
metadata, body = result
933941
action = getattr(metadata, 'action', None)
934-
935942
if action != 'execute':
936943
raise RuntimeError('Expected execute action, not {}'.format(action))
937-
938-
finished = getattr(metadata, 'finished', False)
939944
self._record_writer.is_flushed = False
940945

941-
if len(body) > 0:
942-
reader = csv.reader(StringIO(body), dialect=CsvDialect)
946+
self._execute_chunk_v2(process, result)
943947

944-
try:
945-
fieldnames = next(reader)
946-
except StopIteration:
947-
return
948+
self._record_writer.write_chunk()
948949

949-
mv_fieldnames = dict([(name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')])
950-
951-
if len(mv_fieldnames) == 0:
952-
for values in reader:
953-
yield OrderedDict(izip(fieldnames, values))
954-
else:
955-
for values in reader:
956-
record = OrderedDict()
957-
for fieldname, value in izip(fieldnames, values):
958-
if fieldname.startswith('__mv_'):
959-
if len(value) > 0:
960-
record[mv_fieldnames[fieldname]] = self._decode_list(value)
961-
elif fieldname not in record:
962-
record[fieldname] = value
963-
yield record
964-
965-
if finished:
966-
return
950+
def _execute_chunk_v2(self, process, chunk):
951+
metadata, body = chunk
952+
if len(body) > 0:
953+
records = self._read_csv_records(StringIO(body))
954+
self._record_writer.write_records(process(records))
967955

968-
self.flush()
969956

970957
def _report_unexpected_error(self):
971958

0 commit comments

Comments
 (0)