Skip to content

Commit 005cc53

Browse files
author
Tobias Kopp
committed
[Benchmark] Refactor DuckDB Connector
1 parent c800d8a commit 005cc53

File tree

1 file changed

+63
-53
lines changed

1 file changed

+63
-53
lines changed

benchmark/database_connectors/duckdb.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
TMP_DB = 'tmp.duckdb'
10-
TMP_SQL_FILE = 'tmp.sql'
1110

1211
class DuckDB(Connector):
1312

@@ -39,15 +38,22 @@ def execute(self, n_runs, params: dict):
3938
for _ in range(n_runs):
4039
try:
4140
# Set up database
42-
self.generate_create_table_stmts(params['data'], with_scale_factors)
41+
create_tbl_stmts = self.generate_create_table_stmts(params['data'], with_scale_factors)
4342

4443

4544
# If tables contain scale factors, they have to be loaded separately for every case
4645
if (with_scale_factors or not bool(params.get('readonly'))):
47-
timeout = (DEFAULT_TIMEOUT + TIMEOUT_PER_CASE) * len(params['cases'])
46+
timeout = DEFAULT_TIMEOUT + TIMEOUT_PER_CASE
4847
# Write cases/queries to a file that will be passed to the command to execute
49-
statements = list()
50-
for case, query_stmt in params['cases'].items():
48+
for i in range(len(params['cases'])):
49+
case = list(params['cases'].keys())[i]
50+
query_stmt = list(params['cases'].values())[i]
51+
52+
statements = list()
53+
if i==0:
54+
# Also use the create table stmts in this case
55+
statements = create_tbl_stmts
56+
5157
# Create tables from tmp tables with scale factor
5258
for table_name, table in params['data'].items():
5359
statements.append(f"DELETE FROM {table_name};") # empty existing table
@@ -63,51 +69,54 @@ def execute(self, n_runs, params: dict):
6369
statements.append(query_stmt) # Actual query from this case
6470
statements.append(".timer off")
6571

66-
# Append statements to file
67-
with open(TMP_SQL_FILE, "a+") as tmp:
68-
for stmt in statements:
69-
tmp.write(stmt + "\n")
72+
combined_query = "\n".join(statements)
73+
74+
if self.verbose and not verbose_printed:
75+
verbose_printed = True
76+
tqdm.write(combined_query)
77+
78+
benchmark_info = f"{suite}/{benchmark}/{experiment} [{configname}]"
79+
try:
80+
time = self.run_query(combined_query, timeout, benchmark_info)[0]
81+
except ExperimentTimeoutExpired as ex:
82+
time = timeout
83+
84+
if case not in measurement_times.keys():
85+
measurement_times[case] = list()
86+
measurement_times[case].append(time)
7087

7188

7289

7390
# Otherwise, tables have to be created just once before the measurements (done above)
7491
else:
7592
timeout = DEFAULT_TIMEOUT + TIMEOUT_PER_CASE * len(params['cases'])
76-
# Write cases/queries to a file that will be passed to the command to execute
77-
with open(TMP_SQL_FILE, "a+") as tmp:
78-
tmp.write(".timer on\n")
79-
for case_query in params['cases'].values():
80-
tmp.write(case_query + '\n')
81-
tmp.write(".timer off\n")
8293

94+
statements = create_tbl_stmts
95+
statements.append(".timer on")
96+
for case_query in params['cases'].values():
97+
statements.append(case_query)
98+
statements.append(".timer off")
8399

84-
# Execute query file and collect measurement data
85-
command = f"./{self.duckdb_cli} {TMP_DB} < {TMP_SQL_FILE}" + " | grep 'Run Time' | cut -d ' ' -f 5 | awk '{print $1 * 1000;}'"
86-
if not self.multithreaded:
87-
command = 'taskset -c 2 ' + command
100+
combined_query = "\n".join(statements)
88101

89-
if self.verbose:
90-
tqdm.write(f" $ {command}")
91-
if not verbose_printed:
102+
if self.verbose and not verbose_printed:
92103
verbose_printed = True
93-
with open(TMP_SQL_FILE) as tmp:
94-
tqdm.write(" " + " ".join(tmp.readlines()))
95-
96-
benchmark_info = f"{suite}/{benchmark}/{experiment} [{configname}]"
97-
try:
98-
durations = self.run_command(command, timeout, benchmark_info)
99-
except ExperimentTimeoutExpired as ex:
100-
for case in params['cases'].keys():
101-
if case not in measurement_times.keys():
102-
measurement_times[case] = list()
103-
measurement_times[case].append(TIMEOUT_PER_CASE * 1000)
104-
else:
105-
for idx, line in enumerate(durations):
106-
time = float(line.replace("\n", "").replace(",", ".")) # in milliseconds
107-
case = list(params['cases'].keys())[idx]
108-
if case not in measurement_times.keys():
109-
measurement_times[case] = list()
110-
measurement_times[case].append(time)
104+
tqdm.write(combined_query)
105+
106+
benchmark_info = f"{suite}/{benchmark}/{experiment} [{configname}]"
107+
try:
108+
durations = self.run_query(combined_query, timeout, benchmark_info)
109+
except ExperimentTimeoutExpired as ex:
110+
for case in params['cases'].keys():
111+
if case not in measurement_times.keys():
112+
measurement_times[case] = list()
113+
measurement_times[case].append(timeout * 1000)
114+
else:
115+
for idx, time in enumerate(durations):
116+
case = list(params['cases'].keys())[idx]
117+
if case not in measurement_times.keys():
118+
measurement_times[case] = list()
119+
measurement_times[case].append(time)
111120

112121

113122
finally:
@@ -120,8 +129,6 @@ def execute(self, n_runs, params: dict):
120129
def clean_up(self):
121130
if os.path.exists(TMP_DB):
122131
os.remove(TMP_DB)
123-
if os.path.exists(TMP_SQL_FILE):
124-
os.remove(TMP_SQL_FILE)
125132

126133

127134
# Parse attributes of one table, return as string
@@ -187,29 +194,30 @@ def generate_create_table_stmts(self, data: dict, with_scale_factors):
187194
# Create actual table that will be used for experiment
188195
statements.append(f"CREATE TABLE {table_name[:-4]} {columns};")
189196

190-
with open(TMP_SQL_FILE, "w") as tmp:
191-
for stmt in statements:
192-
tmp.write(stmt + "\n")
197+
return statements
198+
193199

200+
def run_query(self, query, timeout, benchmark_info):
201+
command = f"./{self.duckdb_cli} {TMP_DB}"
202+
if not self.multithreaded:
203+
command = 'taskset -c 2 ' + command
194204

195-
def run_command(self, command, timeout, benchmark_info):
196205
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
197-
cwd=os.getcwd(), shell=True)
206+
cwd=os.getcwd(), shell=True, text=True)
198207
try:
199-
out, err = process.communicate("".encode('latin-1'), timeout=timeout)
208+
out, err = process.communicate(query, timeout=timeout)
200209
except subprocess.TimeoutExpired:
201210
process.kill()
211+
tqdm.write(f" ! Query \n'{query}'\n' timed out after {timeout} seconds")
202212
raise ExperimentTimeoutExpired(f'Query timed out after {timeout} seconds')
203213
finally:
204214
if process.poll() is None: # if process is still alive
205215
process.terminate() # try to shut down gracefully
206216
try:
207-
process.wait(timeout=5) # wait for process to terminate
217+
process.wait(timeout=1) # wait for process to terminate
208218
except subprocess.TimeoutExpired:
209219
process.kill() # kill if process did not terminate in time
210220

211-
out = out.decode('latin-1')
212-
err = err.decode('latin-1')
213221

214222
if process.returncode or len(err):
215223
outstr = '\n'.join(out.split('\n')[-20:])
@@ -227,7 +235,9 @@ def run_command(self, command, timeout, benchmark_info):
227235
raise ConnectorException(f'Benchmark failed with return code {process.returncode}.')
228236

229237
# Parse `out` for timings
230-
durations = out.split('\n')
238+
durations = os.popen(f"echo '{out}'" + " | grep 'Run Time' | cut -d ' ' -f 5 | awk '{print $1 * 1000;}'").read()
239+
durations = durations.split('\n')
231240
durations.remove('')
241+
durations = [float(i.replace("\n", "").replace(",", ".")) for i in durations]
232242

233-
return durations
243+
return durations

0 commit comments

Comments
 (0)