|
22 | 22 | from dataclasses import dataclass
|
23 | 23 | from typing import Dict, List, Optional
|
24 | 24 |
|
25 |
| -import httpx |
| 25 | +import aiohttp |
26 | 26 | import requests
|
27 | 27 | from kubernetes import client, config, watch
|
28 | 28 |
|
@@ -308,22 +308,29 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
|
308 | 308 | model_info=self._get_model_info(model),
|
309 | 309 | )
|
310 | 310 | endpoint_infos.append(endpoint_info)
|
| 311 | + return endpoint_infos |
| 312 | + |
| 313 | + async def initialize_client_sessions(self) -> None: |
| 314 | + """ |
| 315 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 316 | + This must be called from an async context during app startup. |
| 317 | + """ |
311 | 318 | if (
|
312 | 319 | self.prefill_model_labels is not None
|
313 | 320 | and self.decode_model_labels is not None
|
314 | 321 | ):
|
| 322 | + endpoint_infos = self.get_endpoint_info() |
315 | 323 | for endpoint_info in endpoint_infos:
|
316 | 324 | if endpoint_info.model_label in self.prefill_model_labels:
|
317 |
| - self.app.state.prefill_client = httpx.AsyncClient( |
| 325 | + self.app.state.prefill_client = aiohttp.ClientSession( |
318 | 326 | base_url=endpoint_info.url,
|
319 |
| - timeout=None, |
| 327 | + timeout=aiohttp.ClientTimeout(total=None), |
320 | 328 | )
|
321 | 329 | elif endpoint_info.model_label in self.decode_model_labels:
|
322 |
| - self.app.state.decode_client = httpx.AsyncClient( |
| 330 | + self.app.state.decode_client = aiohttp.ClientSession( |
323 | 331 | base_url=endpoint_info.url,
|
324 |
| - timeout=None, |
| 332 | + timeout=aiohttp.ClientTimeout(total=None), |
325 | 333 | )
|
326 |
| - return endpoint_infos |
327 | 334 |
|
328 | 335 |
|
329 | 336 | class K8sPodIPServiceDiscovery(ServiceDiscovery):
|
@@ -629,20 +636,7 @@ def _add_engine(
|
629 | 636 | namespace=self.namespace,
|
630 | 637 | model_info=model_info,
|
631 | 638 | )
|
632 |
| - if ( |
633 |
| - self.prefill_model_labels is not None |
634 |
| - and self.decode_model_labels is not None |
635 |
| - ): |
636 |
| - if model_label in self.prefill_model_labels: |
637 |
| - self.app.state.prefill_client = httpx.AsyncClient( |
638 |
| - base_url=f"http://{engine_ip}:{self.port}", |
639 |
| - timeout=None, |
640 |
| - ) |
641 |
| - elif model_label in self.decode_model_labels: |
642 |
| - self.app.state.decode_client = httpx.AsyncClient( |
643 |
| - base_url=f"http://{engine_ip}:{self.port}", |
644 |
| - timeout=None, |
645 |
| - ) |
| 639 | + |
646 | 640 | # Store model information in the endpoint info
|
647 | 641 | self.available_engines[engine_name].model_info = model_info
|
648 | 642 |
|
@@ -720,6 +714,28 @@ def close(self):
|
720 | 714 | self.k8s_watcher.stop()
|
721 | 715 | self.watcher_thread.join()
|
722 | 716 |
|
| 717 | + async def initialize_client_sessions(self) -> None: |
| 718 | + """ |
| 719 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 720 | + This must be called from an async context during app startup. |
| 721 | + """ |
| 722 | + if ( |
| 723 | + self.prefill_model_labels is not None |
| 724 | + and self.decode_model_labels is not None |
| 725 | + ): |
| 726 | + endpoint_infos = self.get_endpoint_info() |
| 727 | + for endpoint_info in endpoint_infos: |
| 728 | + if endpoint_info.model_label in self.prefill_model_labels: |
| 729 | + self.app.state.prefill_client = aiohttp.ClientSession( |
| 730 | + base_url=endpoint_info.url, |
| 731 | + timeout=aiohttp.ClientTimeout(total=None), |
| 732 | + ) |
| 733 | + elif endpoint_info.model_label in self.decode_model_labels: |
| 734 | + self.app.state.decode_client = aiohttp.ClientSession( |
| 735 | + base_url=endpoint_info.url, |
| 736 | + timeout=aiohttp.ClientTimeout(total=None), |
| 737 | + ) |
| 738 | + |
723 | 739 |
|
724 | 740 | class K8sServiceNameServiceDiscovery(ServiceDiscovery):
|
725 | 741 | def __init__(
|
@@ -1024,20 +1040,7 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str
|
1024 | 1040 | namespace=self.namespace,
|
1025 | 1041 | model_info=model_info,
|
1026 | 1042 | )
|
1027 |
| - if ( |
1028 |
| - self.prefill_model_labels is not None |
1029 |
| - and self.decode_model_labels is not None |
1030 |
| - ): |
1031 |
| - if model_label in self.prefill_model_labels: |
1032 |
| - self.app.state.prefill_client = httpx.AsyncClient( |
1033 |
| - base_url=f"http://{engine_name}:{self.port}", |
1034 |
| - timeout=None, |
1035 |
| - ) |
1036 |
| - elif model_label in self.decode_model_labels: |
1037 |
| - self.app.state.decode_client = httpx.AsyncClient( |
1038 |
| - base_url=f"http://{engine_name}:{self.port}", |
1039 |
| - timeout=None, |
1040 |
| - ) |
| 1043 | + |
1041 | 1044 | # Store model information in the endpoint info
|
1042 | 1045 | self.available_engines[engine_name].model_info = model_info
|
1043 | 1046 |
|
@@ -1114,6 +1117,28 @@ def close(self):
|
1114 | 1117 | self.k8s_watcher.stop()
|
1115 | 1118 | self.watcher_thread.join()
|
1116 | 1119 |
|
| 1120 | + async def initialize_client_sessions(self) -> None: |
| 1121 | + """ |
| 1122 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 1123 | + This must be called from an async context during app startup. |
| 1124 | + """ |
| 1125 | + if ( |
| 1126 | + self.prefill_model_labels is not None |
| 1127 | + and self.decode_model_labels is not None |
| 1128 | + ): |
| 1129 | + endpoint_infos = self.get_endpoint_info() |
| 1130 | + for endpoint_info in endpoint_infos: |
| 1131 | + if endpoint_info.model_label in self.prefill_model_labels: |
| 1132 | + self.app.state.prefill_client = aiohttp.ClientSession( |
| 1133 | + base_url=endpoint_info.url, |
| 1134 | + timeout=aiohttp.ClientTimeout(total=None), |
| 1135 | + ) |
| 1136 | + elif endpoint_info.model_label in self.decode_model_labels: |
| 1137 | + self.app.state.decode_client = aiohttp.ClientSession( |
| 1138 | + base_url=endpoint_info.url, |
| 1139 | + timeout=aiohttp.ClientTimeout(total=None), |
| 1140 | + ) |
| 1141 | + |
1117 | 1142 |
|
1118 | 1143 | def _create_service_discovery(
|
1119 | 1144 | service_discovery_type: ServiceDiscoveryType, *args, **kwargs
|
|
0 commit comments