1+ import base64
12import json
3+ import time
24from typing import Any , Dict , Optional , Tuple
35
46import pyarrow .flight as flight
57from pandas .core .frame import DataFrame
8+ from pyarrow .flight import ClientMiddleware , ClientMiddlewareFactory
69
710from .arrow_graph_constructor import ArrowGraphConstructor
811from .graph_constructor import GraphConstructor
@@ -28,16 +31,15 @@ def __init__(
2831 else flight .Location .for_grpc_tcp (host , int (port_string ))
2932 )
3033
31- self ._flight_client = flight .FlightClient (location , disable_server_verification = disable_server_verification )
32- self ._flight_options = flight .FlightCallOptions ()
33-
34+ client_options : Dict [str , Any ] = {"disable_server_verification" : disable_server_verification }
3435 if auth :
35- username , password = auth
36- header , token = self ._flight_client .authenticate_basic_token (username , password )
37- if header :
38- self ._flight_options = flight .FlightCallOptions (headers = [(header , token )])
36+ client_options ["middleware" ] = [AuthFactory (auth )]
37+
38+ self ._flight_client = flight .FlightClient (location , ** client_options )
3939
40- def run_query (self , query : str , params : Dict [str , Any ] = {}) -> DataFrame :
40+ def run_query (self , query : str , params : Optional [Dict [str , Any ]] = None ) -> DataFrame :
41+ if params is None :
42+ params = {}
4143 if "gds.graph.streamNodeProperty" in query :
4244 graph_name = params ["graph_name" ]
4345 property_name = params ["properties" ]
@@ -57,8 +59,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
5759
5860 return self ._fallback_query_runner .run_query (query , params )
5961
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
    """Run ``query`` through the fallback query runner's logging variant.

    Arrow queries have no logging support for now, so the call is always
    delegated to ``self._fallback_query_runner``.

    Args:
        query: The query string to execute.
        params: Optional query parameters; a fresh empty dict is used when
            omitted, so no mutable default is shared between calls.

    Returns:
        Whatever the fallback runner's ``run_query_with_logging`` returns.
    """
    effective_params = {} if params is None else params
    fallback_runner = self._fallback_query_runner
    return fallback_runner.run_query_with_logging(query, effective_params)
6367
6468 def set_database (self , db : str ) -> None :
@@ -79,9 +83,61 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
7983 }
8084 ticket = flight .Ticket (json .dumps (payload ).encode ("utf-8" ))
8185
82- result : DataFrame = self ._flight_client .do_get (ticket , self ._flight_options ).read_pandas ()
86+ get = self ._flight_client .do_get (ticket )
87+ result : DataFrame = get .read_pandas ()
8388
8489 return result
8590
def create_graph_constructor(self, graph_name: str, concurrency: int) -> GraphConstructor:
    """Create an Arrow-based constructor for the graph named ``graph_name``.

    Args:
        graph_name: Name of the graph to construct.
        concurrency: Concurrency setting forwarded unchanged to the
            ``ArrowGraphConstructor``.

    Returns:
        An ``ArrowGraphConstructor`` that shares this runner's Flight client.
    """
    # Authentication is handled by the client's middleware, so no per-call
    # Flight options need to be handed over any more.
    flight_client = self._flight_client
    return ArrowGraphConstructor(self, graph_name, flight_client, concurrency)
93+
94+
class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that keeps the credentials and a cached bearer token.

    Every ``AuthMiddleware`` created by ``start_call`` refers back to this
    factory, so a token stored by one middleware instance is visible to the
    ones created afterwards.
    """

    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
        """Store the credentials and start without a cached token.

        Args:
            auth: The credentials pair; consumers unpack it as
                ``(username, password)``.
            *args: Forwarded to ``ClientMiddlewareFactory.__init__``.
            **kwargs: Forwarded to ``ClientMiddlewareFactory.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        # Whole seconds since the epoch at which the token was last stored.
        self._token_timestamp = 0

    def start_call(self, info: Any) -> "AuthMiddleware":
        """Return a new ``AuthMiddleware`` bound to this factory."""
        return AuthMiddleware(self)

    def token(self) -> Optional[str]:
        """Return the cached token, dropping it first if it is too old.

        Returns:
            The cached token, or ``None`` when none is stored or the stored
            one is older than ten minutes.
        """
        if not self._token:
            # Nothing cached, so there is nothing to expire.
            return self._token

        max_age_seconds = 600  # ten minutes
        age_seconds = int(time.time()) - self._token_timestamp
        if age_seconds > max_age_seconds:
            self._token = None

        return self._token

    def set_token(self, token: str) -> None:
        """Cache ``token`` and record the current time as its timestamp."""
        self._token = token
        self._token_timestamp = int(time.time())

    @property
    def auth(self) -> Tuple[str, str]:
        """The credentials pair passed to the constructor."""
        return self._auth
119+
120+
class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Per-call middleware that attaches credentials and caches bearer tokens.

    Outgoing calls carry a ``Bearer`` token when the factory holds one and
    fall back to HTTP ``Basic`` credentials otherwise. A ``Bearer`` token
    found in the response headers is stored on the factory for later calls.
    """

    def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
        """Bind the middleware to the factory holding credentials and token.

        Args:
            factory: The ``AuthFactory`` that created this middleware.
            *args: Forwarded to ``ClientMiddleware.__init__``.
            **kwargs: Forwarded to ``ClientMiddleware.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._factory = factory

    def received_headers(self, headers: Dict[str, Any]) -> None:
        """Cache the bearer token from the response headers, if there is one.

        Malformed or non-``Bearer`` authorization headers are ignored so a
        surprising response header cannot fail the call itself.

        Args:
            headers: Response headers. pyarrow's ``ClientMiddleware`` docs
                describe the values as lists of ``str`` (or ``bytes`` for
                binary headers); a bare string is accepted as well.
        """
        # Header names arrive lower-cased over gRPC, so match the name
        # case-insensitively instead of relying on one exact spelling.
        auth_header: Any = None
        for name, value in headers.items():
            if isinstance(name, str) and name.lower() == "authorization":
                auth_header = value
                break
        if not auth_header:
            return

        # Unwrap the list form; the first value is the one that matters.
        if isinstance(auth_header, (list, tuple)):
            auth_header = auth_header[0]
        if not isinstance(auth_header, str):
            # Not a text header value; there is nothing usable to cache.
            return

        # ``partition`` never raises, unlike unpacking the result of ``split``
        # when the value has no space in it.
        auth_type, _, token = auth_header.partition(" ")
        if auth_type == "Bearer" and token:
            self._factory.set_token(token)

    def sending_headers(self) -> Dict[str, str]:
        """Build the authorization header for the outgoing call.

        Returns:
            A single-entry dict holding either the cached ``Bearer`` token or
            the base64-encoded ``Basic`` credentials.
        """
        # The header name has to be spelled in lower case here; the
        # capitalised form did not work (looks like an upstream bug).
        token = self._factory.token()
        if token:
            return {"authorization": "Bearer " + token}

        username, password = self._factory.auth
        credentials = f"{username}:{password}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ASCII")
        return {"authorization": "Basic " + encoded}
0 commit comments