|
1 | | -""" |
2 | | -ann-benchmarks interfaces for Elasticsearch. |
3 | | -Note that this requires X-Pack, which is not included in the OSS version of Elasticsearch. |
4 | | -""" |
5 | | -import logging |
from time import monotonic, sleep

from elasticsearch import ConnectionError, Elasticsearch
from elasticsearch.helpers import bulk

from .base import BaseANN
14 | 7 |
|
15 | | -# Configure the elasticsearch logger. |
16 | | -# By default, it writes an INFO statement for every request. |
17 | | -logging.getLogger("elasticsearch").setLevel(logging.WARN) |
18 | 8 |
|
19 | | -# Uncomment these lines if you want to see timing for every HTTP request and its duration. |
20 | | -# logging.basicConfig(level=logging.INFO) |
21 | | -# logging.getLogger("elasticsearch").setLevel(logging.INFO) |
class ElasticsearchKNN(BaseANN):
    """Elasticsearch KNN search.

    See https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html for more details.
    """

    def __init__(self, metric: str, dimension: int, index_options: dict):
        """Create the runner and block until the local Elasticsearch node is reachable.

        Args:
            metric: benchmark distance metric, "angular" or "euclidean".
            dimension: dimensionality of the indexed vectors.
            index_options: dense_vector HNSW options; requires "m" and
                "ef_construction", optional "type" (defaults to "hnsw").

        Raises:
            NotImplementedError: if `metric` is not supported.
            RuntimeError: if Elasticsearch does not become healthy in time.
        """
        self.metric = metric
        self.dimension = dimension
        self.index_options = index_options
        # Default search breadth; overridden per-run via set_query_arguments().
        self.num_candidates = 100

        # Sort the options so identical configurations always yield the same
        # name (self.name doubles as the Elasticsearch index name).
        index_options_str = "-".join(sorted(f"{k}-{v}" for k, v in self.index_options.items()))
        self.name = f"es-{metric}-{dimension}-{index_options_str}"
        self.similarity_metric = self._vector_similarity_metric(metric)

        self.client = Elasticsearch(["http://localhost:9200"])
        self.batch_res = []
        self._wait_for_health_status()

    def _vector_similarity_metric(self, metric: str):
        """Map a benchmark metric name to the dense_vector `similarity` value.

        Raises:
            NotImplementedError: for metrics Elasticsearch does not support here.
        """
        # `dot_product` is more efficient than `cosine`, but requires all vectors to be normalized
        # to unit length. We opt for adaptability, some datasets might not be normalized.
        supported_metrics = {
            "angular": "cosine",
            "euclidean": "l2_norm",
        }
        if metric not in supported_metrics:
            raise NotImplementedError(f"{metric} is not implemented")
        return supported_metrics[metric]

    def _wait_for_health_status(self, wait_seconds=30, status="yellow"):
        """Poll cluster health until it reaches `status` or `wait_seconds` elapse.

        Raises:
            RuntimeError: if the cluster is not healthy within the deadline.
        """
        print("Waiting for Elasticsearch ...")
        # Use a wall-clock deadline rather than counting iterations: each
        # attempt costs up to ~2s (1s request timeout + 1s sleep), so an
        # iteration count would roughly double the intended wait time.
        deadline = monotonic() + wait_seconds
        while monotonic() < deadline:
            try:
                health = self.client.cluster.health(wait_for_status=status, request_timeout=1)
                print(f'Elasticsearch is ready: status={health["status"]}')
                return
            except ConnectionError:
                # Node not up yet; retry until the deadline.
                pass
            sleep(1)
        raise RuntimeError("Failed to connect to Elasticsearch")

    def fit(self, X):
        """Create the index, bulk-load all vectors of `X`, then force-merge.

        Args:
            X: 2-D array-like of float vectors; row i is indexed with id=str(i).

        Raises:
            RuntimeError: if any document fails to index.
        """
        settings = {
            "number_of_shards": 1,
            "number_of_replicas": 0,
            # Disable periodic refresh during the load; we refresh explicitly
            # once at the end.
            "refresh_interval": -1,
        }
        mappings = {
            "properties": {
                "id": {"type": "keyword", "store": True},
                "vec": {
                    "type": "dense_vector",
                    "element_type": "float",
                    "dims": self.dimension,
                    "index": True,
                    "similarity": self.similarity_metric,
                    "index_options": {
                        "type": self.index_options.get("type", "hnsw"),
                        "m": self.index_options["m"],
                        "ef_construction": self.index_options["ef_construction"],
                    },
                },
            },
        }
        self.client.indices.create(index=self.name, settings=settings, mappings=mappings)

        def gen():
            # Stream actions so the full dataset is never duplicated in memory.
            for i, vec in enumerate(X):
                yield {"_op_type": "index", "_index": self.name, "id": str(i), "vec": vec.tolist()}

        print("Indexing ...")
        # NOTE(review): helpers.bulk raises BulkIndexError on failures by
        # default, so the explicit check below is a safety net.
        (_, errors) = bulk(self.client, gen(), chunk_size=500, request_timeout=90)
        if len(errors) != 0:
            raise RuntimeError("Failed to index documents")

        # Merging to a single segment gives the fastest (and most stable)
        # query-time performance for a static benchmark index.
        print("Force merge index ...")
        self.client.indices.forcemerge(index=self.name, max_num_segments=1, request_timeout=900)

        print("Refreshing index ...")
        self.client.indices.refresh(index=self.name, request_timeout=900)

    def set_query_arguments(self, num_candidates):
        """Set the HNSW search breadth (`num_candidates`) for subsequent queries."""
        self.num_candidates = num_candidates

    def query(self, q, n):
        """Return the ids of the `n` approximate nearest neighbors of vector `q`.

        Raises:
            ValueError: if `n` exceeds the configured num_candidates
                (Elasticsearch requires k <= num_candidates).
        """
        if n > self.num_candidates:
            raise ValueError("n must be less than or equal to num_candidates")

        body = {
            "knn": {
                "field": "vec",
                "query_vector": q.tolist(),
                "k": n,
                "num_candidates": self.num_candidates,
            }
        }
        # Fetch only the stored `id` docvalue and filter the response down to
        # it, which keeps the measured latency close to pure search time.
        res = self.client.search(
            index=self.name,
            body=body,
            size=n,
            _source=False,
            docvalue_fields=["id"],
            stored_fields="_none_",
            filter_path=["hits.hits.fields.id"],
            request_timeout=10,
        )
        return [int(h["fields"]["id"][0]) for h in res["hits"]["hits"]]

    def batch_query(self, X, n):
        """Run query() for every vector in `X`, storing results in batch_res."""
        self.batch_res = [self.query(q, n) for q in X]
|
0 commit comments