Commit 5bdb921

benchmark script
1 parent 453e383 commit 5bdb921

File tree

1 file changed: +141 -0

1 file changed

+141
-0
lines changed

utils/dense-vector-benchmark.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import asyncio
import json
import os
import time

import numpy as np

from elasticsearch import OrjsonSerializer
from elasticsearch.dsl import AsyncDocument, NumpyDenseVector, async_connections
from elasticsearch.dsl.types import DenseVectorIndexOptions
from elasticsearch.helpers import async_bulk, pack_dense_vector

async_connections.create_connection(
    hosts=[os.environ["ELASTICSEARCH_URL"]], serializer=OrjsonSerializer()
)
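# NOTE: OrjsonSerializer is presumably chosen here because orjson serializes
# numpy arrays natively and is much faster than the stdlib json encoder,
# which matters when shipping large float32 embeddings in request bodies.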


class Doc(AsyncDocument):
    title: str
    text: str
    emb: np.ndarray = NumpyDenseVector(
        dtype=np.float32, index_options=DenseVectorIndexOptions(type="flat")
    )

    class Index:
        name = "benchmark"

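# NOTE: "flat" index options presumably select brute-force vector storage
# (no HNSW graph construction), so the timings below measure ingest,
# serialization and transport cost rather than graph building.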
async def upload(data_file: str, chunk_size: int, pack: bool) -> tuple[int, float]:
    with open(data_file, "rt") as f:
        # read the data file, which comes in ndjson format, and convert it to JSON
        json_data = "[" + f.read().strip().replace("\n", ",") + "]"
    dataset = json.loads(json_data)
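    # Each input line must carry the fields accessed below; for example
    # (values illustrative):
    #   {"docid": "0", "title": "...", "text": "...", "emb": [0.12, -0.34]}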

    # replace the embedding lists with numpy arrays for performance
    dataset = [
        {
            "docid": doc["docid"],
            "title": doc["title"],
            "text": doc["text"],
            "emb": np.array(doc["emb"], dtype=np.float32),
        }
        for doc in dataset
    ]

    # create mapping and index, recreating the index if it already exists
    if await Doc._index.exists():
        await Doc._index.delete()
    await Doc.init()
    await Doc._index.refresh()

    async def get_next_document():
        for doc in dataset:
            yield {
                "_index": "benchmark",
                "_id": doc["docid"],
                "_source": {
                    "title": doc["title"],
                    "text": doc["text"],
                    "emb": doc["emb"],
                },
            }

    async def get_next_document_packed():
        for doc in dataset:
            yield {
                "_index": "benchmark",
                "_id": doc["docid"],
                "_source": {
                    "title": doc["title"],
                    "text": doc["text"],
                    "emb": pack_dense_vector(doc["emb"]),
                },
            }

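    # Conceptually (a sketch, not necessarily the library's exact
    # implementation), pack_dense_vector() ships the float32 buffer as a
    # single base64 string, roughly base64.b64encode(emb.tobytes()),
    # instead of a JSON list of floats: smaller on the wire and much
    # cheaper to serialize.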
    start = time.time()
    result = await async_bulk(
        client=async_connections.get_connection(),
        chunk_size=chunk_size,
        actions=get_next_document_packed() if pack else get_next_document(),
        stats_only=True,
    )
    duration = time.time() - start
    assert result[1] == 0  # with stats_only=True, result[1] is the error count
    return result[0], duration


async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_file", metavar="JSON_DATA_FILE")
    parser.add_argument(
        "--chunk-sizes",
        "-s",
        nargs="+",
        type=int,
        required=True,
        help="Chunk size(s) for bulk uploader",
    )
    args = parser.parse_args()

    for chunk_size in args.chunk_sizes:
        print(f"Uploading '{args.data_file}' with chunk size {chunk_size}...")
        runs = []
        packed_runs = []
        for _ in range(3):
            runs.append(await upload(args.data_file, chunk_size, False))
            packed_runs.append(await upload(args.data_file, chunk_size, True))

        # ensure that all runs uploaded the same number of documents
        size = runs[0][0]
        for run in runs:
            assert run[0] == size
        for run in packed_runs:
            assert run[0] == size

        # average the durations of the three runs of each variant
        dur = sum(run[1] for run in runs) / len(runs)
        packed_dur = sum(run[1] for run in packed_runs) / len(packed_runs)

        print(f"Size: {size}")
        print(f"float duration: {dur:.02f}s / {size / dur:.02f} docs/s")
        print(
            f"float base64 duration: {packed_dur:.02f}s / {size / packed_dur:.02f} docs/s"
        )
        print(f"Speed up: {dur / packed_dur:.02f}x")


if __name__ == "__main__":
    asyncio.run(main())
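
For reference, a typical invocation might look like this (the host and data
file are placeholders; the ELASTICSEARCH_URL variable, the positional data
file argument and the --chunk-sizes option are the script's actual inputs):

    ELASTICSEARCH_URL=http://localhost:9200 python utils/dense-vector-benchmark.py data.ndjson --chunk-sizes 250 500 1000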

0 commit comments