Skip to content

Commit bc9d2a7

Browse files
authored
Merge branch 'main' into python_3_13
2 parents 00304d0 + e65f09c commit bc9d2a7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+6304
-546
lines changed

README-development.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
The Oracle Accelerated Data Science (ADS) SDK is used by data scientists and analysts for
55
data exploration and experimental machine learning to democratize machine learning and
66
analytics by providing easy-to-use,
7-
performant, and user friendly tools that
7+
performant, and user-friendly tools that
88
bring together the best of data science practices.
99

1010
The ADS SDK helps you connect to different data sources, perform exploratory data analysis,
@@ -176,7 +176,7 @@ pip install -r test-requirements.txt
176176
```
177177

178178
### Step 2: Create local .env files
179-
Running the local JuypterLab server requires setting OCI authentication, proxy, and OCI namespace parameters. Adapt this .env file with your specific OCI profile and OCIDs to set these variables.
179+
Running the local JupyterLab server requires setting OCI authentication, proxy, and OCI namespace parameters. Adapt this .env file with your specific OCI profile and OCIDs to set these variables.
180180

181181
```
182182
CONDA_BUCKET_NS="your_conda_bucket"
@@ -248,7 +248,7 @@ All the unit tests can be found [here](https://github.com/oracle/accelerated-dat
248248
The following commands detail how the unit tests can be run.
249249
```
250250
# Run all tests in AQUA project
251-
python -m pytest -q tests/unitary/with_extras/aqua/test_deployment.py
251+
python -m pytest -q tests/unitary/with_extras/aqua/*
252252
253253
# Run all tests specific to a module within the AQUA project (e.g. test_deployment.py, test_model.py, etc.)
254254
python -m pytest -q tests/unitary/with_extras/aqua/test_deployment.py

ads/aqua/app.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import os
77
import traceback
88
from dataclasses import fields
9+
from datetime import datetime, timedelta
910
from typing import Any, Dict, Optional, Union
1011

1112
import oci
13+
from cachetools import TTLCache, cached
1214
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1315

1416
from ads import set_auth
@@ -269,6 +271,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
269271
logger.info(f"Artifact not found in model {model_id}.")
270272
return False
271273

274+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
272275
def get_config(
273276
self,
274277
model_id: str,
@@ -337,6 +340,9 @@ def get_config(
337340
config_file_path = os.path.join(config_path, config_file_name)
338341
if is_path_exists(config_file_path):
339342
try:
343+
logger.debug(
344+
f"Loading config: `{config_file_name}` from `{config_path}`"
345+
)
340346
config = load_config(
341347
config_path,
342348
config_file_name=config_file_name,

ads/aqua/client/openai_client.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
#!/usr/bin/env python
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import logging
import re

# `Union` is a typing construct; it was previously pulled in through
# `from git import Union`, which made this module depend on GitPython
# for no functional reason.
from typing import Any, Dict, Optional, Union

import httpx

from ads.aqua.client.client import get_async_httpx_client, get_httpx_client
from ads.common.extended_enum import ExtendedEnum

logger = logging.getLogger(__name__)

# Default request timeout: 600 seconds overall, with a shorter 5 second
# limit for establishing the connection.
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0)
DEFAULT_MAX_RETRIES = 2


# The client classes below subclass `openai.OpenAI` / `openai.AsyncOpenAI`,
# so the package is mandatory here; fail at import time with an actionable
# message instead of a bare ImportError.
try:
    import openai
except ImportError as e:
    raise ModuleNotFoundError(
        "The custom OpenAI client requires the `openai-python` package. "
        "Please install it with `pip install openai`."
    ) from e
29+
30+
31+
class ModelDeploymentBaseEndpoint(ExtendedEnum):
    """Base endpoint names supported for model deployments."""

    PREDICT = "predict"
    PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
36+
37+
38+
class AquaOpenAIMixin:
    """
    Mixin that provides common logic to patch HTTP request headers and URLs
    for compatibility with the OCI Model Deployment service using the OpenAI API schema.

    The host class must expose a ``base_url`` attribute; ``_patch_url`` reads
    its ``path`` and calls its ``copy_with`` method.
    """

    def _patch_route(self, original_path: str) -> str:
        """
        Extracts and formats the OpenAI-style route path from a full request path.

        The path is lower-cased and stripped of trailing slashes before matching,
        so the returned route is always lower-case.

        Args:
            original_path (str): The full URL path from the incoming request.

        Returns:
            str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions'),
            or an empty string when the path has no `/predict` or
            `/predictwithresponsestream` segment, or nothing follows that segment.
        """
        normalized_path = original_path.lower().rstrip("/")

        # The lookahead requires the endpoint to be a complete path segment, so
        # a segment such as '/predictions' is not mistaken for '/predict' + 'ions'.
        # This mirrors `_patch_url`, which also requires a '/' after the endpoint.
        match = re.search(r"/predict(withresponsestream)?(?=/|$)", normalized_path)
        if not match:
            logger.debug("Route header cannot be resolved from path: %s", original_path)
            return ""

        route_suffix = normalized_path[match.end() :].lstrip("/")
        if not route_suffix:
            logger.warning(
                "Missing OpenAI route suffix after '/predict'. "
                "Expected something like '/v1/completions'."
            )
            return ""

        # An unversioned suffix is only reported, never rewritten.
        if not route_suffix.startswith("v"):
            logger.warning(
                "Route suffix does not start with a version prefix (e.g., '/v1'). "
                "This may lead to compatibility issues with OpenAI-style endpoints. "
                "Consider updating the URL to include a version prefix, "
                "such as '/predict/v1' or '/predictwithresponsestream/v1'."
            )

        return f"/{route_suffix}"

    def _patch_streaming(self, request: httpx.Request) -> None:
        """
        Sets the 'enable-streaming' header based on the JSON request body contents.

        If the request body is a JSON object containing `"stream": true`, the
        `enable-streaming` header is set to "true". Otherwise, it defaults to "false".
        A header already present on the request is left untouched.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        streaming_enabled = "false"
        content_type = request.headers.get("Content-Type", "")

        if "application/json" in content_type and request.content:
            try:
                body = (
                    request.content.decode("utf-8")
                    if isinstance(request.content, bytes)
                    else request.content
                )
                payload = json.loads(body)
                # Valid JSON is not necessarily an object (it may be a list or
                # a scalar); only an object can carry the "stream" flag.
                if isinstance(payload, dict) and payload.get("stream") is True:
                    streaming_enabled = "true"
            except Exception as e:
                # Deliberately best-effort: an unparsable body must not block
                # the request; it just leaves streaming disabled.
                logger.exception(
                    "Failed to parse request JSON body for streaming flag: %s", e
                )

        request.headers.setdefault("enable-streaming", streaming_enabled)
        # Log the value actually in effect: `setdefault` keeps a caller-supplied
        # header, which may differ from the value computed above.
        logger.debug(
            "Patched 'enable-streaming' header: %s",
            request.headers.get("enable-streaming"),
        )

    def _patch_headers(self, request: httpx.Request) -> None:
        """
        Patches request headers by injecting OpenAI-compatible values:
        - `enable-streaming` for streaming-aware endpoints
        - `route` for backend routing

        Headers already present on the request are not overwritten.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        self._patch_streaming(request)
        route_header = self._patch_route(request.url.path)
        request.headers.setdefault("route", route_header)
        # Report the effective header, which may be a caller-supplied value.
        logger.debug("Patched 'route' header: %s", request.headers.get("route"))

    def _patch_url(self) -> httpx.URL:
        """
        Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.

        The retained path is the lower-cased form of the base URL path.

        Returns:
            httpx.URL: The normalized base URL with the correct model deployment path,
            or the unmodified base URL when no such endpoint segment is found.
        """
        # Guarantee a trailing '/' so the pattern below only matches the
        # endpoint when it is a complete path segment.
        base_path = f"{self.base_url.path.lower().rstrip('/')}/"
        match = re.search(r"/predict(withresponsestream)?/", base_path)
        if match:
            # Drop the trailing '/' that the pattern consumed.
            trimmed = base_path[: match.end() - 1]
            return self.base_url.copy_with(path=trimmed)

        logger.debug("Could not determine a valid endpoint from path: %s", base_path)
        return self.base_url

    def _prepare_request_common(self, request: httpx.Request) -> None:
        """
        Common preparation routine for all requests.

        This includes:
        - Patching headers with streaming and routing info.
        - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        # Patch headers first: the route header is derived from the original
        # request path, which is replaced below.
        logger.debug("Original headers: %s", request.headers)
        self._patch_headers(request)
        logger.debug("Headers after patching: %s", request.headers)

        # Patch URL. The new URL is derived from `base_url` alone, so the
        # request's own path and query string are discarded here.
        # NOTE(review): confirm that dropping the query string is intended.
        logger.debug("URL before patching: %s", request.url)
        request.url = self._patch_url()
        logger.debug("URL after patching: %s", request.url)
161+
162+
163+
class OpenAI(openai.OpenAI, AquaOpenAIMixin):
    """
    Synchronous OpenAI client whose outgoing requests are patched by
    `AquaOpenAIMixin` (route and streaming headers, normalized URL) before being sent.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        http_client: Optional[httpx.Client] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Construct a new synchronous OpenAI client instance.

        If no http_client is provided, one will be automatically created using ads.aqua.get_httpx_client().

        Args:
            api_key (str, optional): API key for authentication. When omitted or empty,
                the placeholder value "OCI" is passed to the parent client.
            organization (str, optional): Organization ID. Defaults to env variable OPENAI_ORG_ID.
            project (str, optional): Project ID. Defaults to env variable OPENAI_PROJECT_ID.
            base_url (str | httpx.URL, optional): Base URL for the API.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Timeout for API requests.
            max_retries (int, optional): Maximum number of retries for API requests.
            default_headers (dict[str, str], optional): Additional headers.
            default_query (dict[str, object], optional): Additional query parameters.
            http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
                Ignored when `http_client` is provided.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Additional keyword arguments passed to the parent __init__.
        """
        if http_client is None:
            logger.debug(
                "No http_client provided; auto-creating one using ads.aqua.get_httpx_client()"
            )
            http_client = get_httpx_client(**(http_client_kwargs or {}))
        if not api_key:
            # Always hand the parent a non-empty key so construction does not
            # depend on an API key being configured.
            logger.debug("API key not provided; using default placeholder for OCI.")
            api_key = "OCI"

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    def _prepare_request(self, request: httpx.Request) -> None:
        """
        Prepare the synchronous HTTP request by applying common modifications.

        Args:
            request (httpx.Request): The outgoing HTTP request.
        """
        self._prepare_request_common(request)
233+
234+
235+
class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
    """
    Asynchronous OpenAI client whose outgoing requests are patched by
    `AquaOpenAIMixin` (route and streaming headers, normalized URL) before being sent.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        # The async client takes an async transport; this was previously
        # annotated as the synchronous `httpx.Client`.
        http_client: Optional[httpx.AsyncClient] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Construct a new asynchronous AsyncOpenAI client instance.

        If no http_client is provided, one will be automatically created using
        ads.aqua.get_async_httpx_client().

        Args:
            api_key (str, optional): API key for authentication. When omitted or empty,
                the placeholder value "OCI" is passed to the parent client.
            organization (str, optional): Organization ID.
            project (str, optional): Project ID.
            base_url (str | httpx.URL, optional): Base URL for the API.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Timeout for API requests.
            max_retries (int, optional): Maximum number of retries for API requests.
            default_headers (dict[str, str], optional): Additional headers.
            default_query (dict[str, object], optional): Additional query parameters.
            http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
                Ignored when `http_client` is provided.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Additional keyword arguments passed to the parent __init__.
        """
        if http_client is None:
            logger.debug(
                "No async http_client provided; auto-creating one using ads.aqua.get_async_httpx_client()"
            )
            http_client = get_async_httpx_client(**(http_client_kwargs or {}))
        if not api_key:
            # Always hand the parent a non-empty key so construction does not
            # depend on an API key being configured.
            logger.debug("API key not provided; using default placeholder for OCI.")
            api_key = "OCI"

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    async def _prepare_request(self, request: httpx.Request) -> None:
        """
        Asynchronously prepare the HTTP request by applying common modifications.

        The shared patching logic itself is synchronous; only this hook is a coroutine.

        Args:
            request (httpx.Request): The outgoing HTTP request.
        """
        self._prepare_request_common(request)

0 commit comments

Comments
 (0)