1- import logging
1+ from time import sleep
22from urllib .request import Request , urlopen
33
4- from elasticsearch import Elasticsearch
5- from elasticsearch .helpers import bulk
4+ from opensearchpy import ConnectionError , OpenSearch
5+ from opensearchpy .helpers import bulk
66from tqdm import tqdm
77
88from .base import BaseANN
9- from .elasticsearch import es_wait
10-
11- # Configure the logger.
12- logging .getLogger ("elasticsearch" ).setLevel (logging .WARN )
139
1410
1511class OpenSearchKNN (BaseANN ):
@@ -19,8 +15,19 @@ def __init__(self, metric, dimension, method_param):
1915 self .method_param = method_param
2016 self .param_string = "-" .join (k + "-" + str (v ) for k , v in self .method_param .items ()).lower ()
2117 self .name = f"os-{ self .param_string } "
22- self .es = Elasticsearch (["http://localhost:9200" ])
23- es_wait ()
18+ self .client = OpenSearch (["http://localhost:9200" ])
19+ self ._wait_for_health_status ()
20+
def _wait_for_health_status(self, wait_seconds=30, status="yellow"):
    """Wait for the OpenSearch cluster to answer a health request.

    Issues ``self.client.cluster.health(wait_for_status=status)`` up to
    ``wait_seconds`` times and returns as soon as one call completes
    without a ``ConnectionError``. Failed attempts are separated by a
    one-second pause.

    Args:
        wait_seconds: Maximum number of attempts. Because attempts are
            spaced one second apart, this is also roughly the number of
            seconds spent retrying.
        status: Value forwarded to the health API as ``wait_for_status``.

    Raises:
        RuntimeError: If no attempt succeeds. The most recent
            ``ConnectionError`` (if any) is chained as ``__cause__`` so
            the underlying failure is not lost.
    """
    last_error = None
    for attempt in range(wait_seconds):
        try:
            self.client.cluster.health(wait_for_status=status)
        except ConnectionError as exc:
            # Keep the failure so it can be reported if every retry fails.
            last_error = exc
        else:
            return
        # No point pausing after the final attempt; fail straight away.
        if attempt < wait_seconds - 1:
            sleep(1)

    raise RuntimeError("Failed to connect to OpenSearch") from last_error
2431
2532 def fit (self , X ):
2633 body = {
@@ -46,36 +53,36 @@ def fit(self, X):
4653 }
4754 }
4855
49- self .es .indices .create (self .name , body = body )
50- self .es .indices .put_mapping (mapping , self .name )
56+ self .client .indices .create (self .name , body = body )
57+ self .client .indices .put_mapping (mapping , self .name )
5158
5259 print ("Uploading data to the Index:" , self .name )
5360
5461 def gen ():
5562 for i , vec in enumerate (tqdm (X )):
5663 yield {"_op_type" : "index" , "_index" : self .name , "vec" : vec .tolist (), "id" : str (i + 1 )}
5764
58- (_ , errors ) = bulk (self .es , gen (), chunk_size = 500 , max_retries = 2 , request_timeout = 10 )
65+ (_ , errors ) = bulk (self .client , gen (), chunk_size = 500 , max_retries = 2 , request_timeout = 10 )
5966 assert len (errors ) == 0 , errors
6067
6168 print ("Force Merge..." )
62- self .es .indices .forcemerge (self .name , max_num_segments = 1 , request_timeout = 1000 )
69+ self .client .indices .forcemerge (self .name , max_num_segments = 1 , request_timeout = 1000 )
6370
6471 print ("Refreshing the Index..." )
65- self .es .indices .refresh (self .name , request_timeout = 1000 )
72+ self .client .indices .refresh (self .name , request_timeout = 1000 )
6673
6774 print ("Running Warmup API..." )
6875 res = urlopen (Request ("http://localhost:9200/_plugins/_knn/warmup/" + self .name + "?pretty" ))
6976 print (res .read ().decode ("utf-8" ))
7077
def set_query_arguments(self, ef):
    """Apply the query-time ``ef_search`` setting to this instance's index.

    Args:
        ef: Value written to the ``knn.algo_param.ef_search`` index
            setting.
    """
    body = {"settings": {"index": {"knn.algo_param.ef_search": ef}}}
    # Scope the update to our own index, matching the other index calls in
    # this class; without ``index`` the request targets every index.
    self.client.indices.put_settings(body=body, index=self.name)
7481
7582 def query (self , q , n ):
7683 body = {"query" : {"knn" : {"vec" : {"vector" : q .tolist (), "k" : n }}}}
7784
78- res = self .es .search (
85+ res = self .client .search (
7986 index = self .name ,
8087 body = body ,
8188 size = n ,
@@ -95,4 +102,4 @@ def get_batch_results(self):
95102 return self .batch_res
96103
def freeIndex(self):
    """Delete the index named ``self.name`` through the OpenSearch client."""
    index_name = self.name
    self.client.indices.delete(index=index_name)
0 commit comments