2222class Client (ABC ):
2323 """An abstract base class for all clients."""
2424
25+ name : str
26+ """The name of this client."""
27+
2528 @classmethod
2629 async def from_config (cls : type [T ], config : Config ) -> T :
2730 """Creates a client using the provided configuration.
@@ -184,7 +187,7 @@ class Clients:
184187 """An async-safe cache of clients."""
185188
186189 lock : Lock
187- clients : dict [type [ Client ] , Client ]
190+ clients : dict [str , Client ]
188191 config : Config
189192
190193 def __init__ (self , config : Config , clients : list [Client ] | None = None ) -> None :
@@ -193,7 +196,7 @@ def __init__(self, config: Config, clients: list[Client] | None = None) -> None:
193196 if clients :
194197 # TODO check for duplicate types in clients list
195198 for client in clients :
196- self .clients [type ( client ) ] = client
199+ self .clients [client . name ] = client
197200 self .config = config
198201
199202 async def get_client (self , href : str ) -> Client :
@@ -205,15 +208,21 @@ async def get_client(self, href: str) -> Client:
205208 Returns:
206209 Client: An instance of that client.
207210 """
211+ # TODO allow dynamic registration of new clients, e.g. via a plugin mechanism
212+
208213 from .earthdata_client import EarthdataClient
209214 from .filesystem_client import FilesystemClient
210215 from .http_client import HttpClient
211216 from .planetary_computer_client import PlanetaryComputerClient
212217 from .s3_client import S3Client
213218
214219 url = URL (href )
215- if not url .host :
216- client_class : type [Client ] = FilesystemClient
220+ if self .config .client_override :
221+ client_class : type [Client ] = _get_client_class_by_name (
222+ self .config .client_override
223+ )
224+ elif not url .host :
225+ client_class = FilesystemClient
217226 elif url .scheme == "s3" :
218227 client_class = S3Client
219228 elif url .host .endswith ("blob.core.windows.net" ):
@@ -226,15 +235,34 @@ async def get_client(self, href: str) -> Client:
226235 raise ValueError (f"could not guess client class for href: { href } " )
227236
228237 async with self .lock :
229- if client_class in self .clients :
230- return self .clients [client_class ]
238+ if client_class . name in self .clients :
239+ return self .clients [client_class . name ]
231240 else :
232241 client = await client_class .from_config (self .config )
233- self .clients [client_class ] = client
242+ self .clients [client_class . name ] = client
234243 return client
235244
236245 async def close_all (self ) -> None :
237246 """Close all clients."""
238247 async with self .lock :
239248 for client in self .clients .values ():
240249 await client .close ()
250+
251+
252+ def _get_client_class_by_name (name : str ) -> type [Client ]:
253+ for client_class in get_client_classes ():
254+ if client_class .name == name :
255+ return client_class
256+ raise ValueError (f"no client with name: { name } " )
257+
258+
259+ def get_client_classes () -> list [type [Client ]]:
260+ """Returns a list of all known subclasses of Client."""
261+
262+ # https://stackoverflow.com/questions/3862310/how-to-find-all-the-subclasses-of-a-class-given-its-name
263+ def all_subclasses (cls : type [Client ]) -> set [type [Client ]]:
264+ return set (cls .__subclasses__ ()).union (
265+ [s for c in cls .__subclasses__ () for s in all_subclasses (c )]
266+ )
267+
268+ return list (all_subclasses (Client )) # type: ignore
0 commit comments