Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.

Commit df7d27f

Browse files
authored
set_current_environment to support openai > 1.0 (Azure#33636)
* set_current_environment to support openai 1.0 * Changelog * Update azure-ai-ml version * Update azure-ai-ml version * Stop using pkg_resources
1 parent e57716f commit df7d27f

File tree

7 files changed

+105
-19
lines changed

7 files changed

+105
-19
lines changed

sdk/ai/azure-ai-resources/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- AzureOpenAIConnection.set_current_environment supports openai 1.0 and above.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/ai/azure-ai-resources/azure/ai/resources/entities/connection_subtypes.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# ---------------------------------------------------------
44

55

6+
import os
67
from typing import Optional
78

89
from azure.ai.ml.entities._credentials import ApiKeyConfiguration
@@ -99,16 +100,33 @@ def set_current_environment(self, credential: Optional[TokenCredential] = None):
99100
:type credential: :class:`~azure.core.credentials.TokenCredential`
100101
"""
101102

102-
import os
103-
def get_api_version_case_insensitive(connection):
104-
if connection.api_version == None:
105-
raise ValueError(f"Connection {connection.name} is being used to set environment variables, but lacks required api_version")
106-
return connection.api_version.lower()
103+
from importlib.metadata import version as get_version
104+
from packaging.version import Version
107105

108-
try:
109-
import openai
110-
except ImportError:
111-
raise Exception("OpenAI SDK not installed. Please install it using `pip install openai`")
106+
openai_version_str = get_version("openai")
107+
openai_version = Version(openai_version_str)
108+
if openai_version >= Version("1.0.0"):
109+
self._set_current_environment_new(credential)
110+
else:
111+
self._set_current_environment_old(credential)
112+
113+
def _get_api_version_case_insensitive(self, connection):
114+
if connection.api_version == None:
115+
raise ValueError(f"Connection {connection.name} is being used to set environment variables, but lacks required api_version")
116+
return connection.api_version.lower()
117+
118+
def _set_current_environment_new(self, credential: Optional[TokenCredential] = None):
119+
if not credential:
120+
os.environ["AZURE_OPENAI_API_KEY"] = self._workspace_connection.credentials.key
121+
else:
122+
token = credential.get_token("https://cognitiveservices.azure.com/.default")
123+
os.environ["AZURE_OPENAI_AD_TOKEN"] = token.token
124+
125+
os.environ["OPENAI_API_VERSION"] = self._get_api_version_case_insensitive(self._workspace_connection)
126+
os.environ["AZURE_OPENAI_ENDPOINT"] = self._workspace_connection.target
127+
128+
def _set_current_environment_old(self, credential: Optional[TokenCredential] = None):
129+
import openai
112130

113131
if not credential:
114132
openai.api_type = "azure"
@@ -123,13 +141,12 @@ def get_api_version_case_insensitive(connection):
123141
openai.api_key = token.token
124142
os.environ["OPENAI_API_KEY"] = token.token
125143

126-
openai.api_version = get_api_version_case_insensitive(self._workspace_connection)
144+
openai.api_version = self._get_api_version_case_insensitive(self._workspace_connection)
127145

128146
openai.api_base = self._workspace_connection.target
129147

130148
os.environ["OPENAI_API_BASE"] = self._workspace_connection.target
131-
os.environ["OPENAI_API_VERSION"] = get_api_version_case_insensitive(self._workspace_connection)
132-
149+
os.environ["OPENAI_API_VERSION"] = self._get_api_version_case_insensitive(self._workspace_connection)
133150

134151
class AzureAISearchConnection(BaseConnection):
135152
"""A Connection for Azure AI Search

sdk/ai/azure-ai-resources/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
python_requires="<4.0,>=3.8",
6767
install_requires=[
6868
# NOTE: To avoid breaking changes in a major version bump, all dependencies should pin an upper bound if possible.
69-
"azure-ai-ml>=1.12.0",
69+
"azure-ai-ml>=1.13.0",
7070
"mlflow-skinny<3",
7171
"opencensus-ext-logging<=0.1.1",
7272
"azure-mgmt-resource<23.0.0,>=22.0.0",
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
import unittest
3+
from unittest.mock import Mock, patch
4+
5+
from azure.ai.resources.entities import AzureOpenAIConnection
6+
from azure.ai.ml.entities._credentials import ApiKeyConfiguration
7+
8+
9+
@pytest.mark.unittest
10+
class TestConnection:
11+
12+
@pytest.fixture()
13+
def mock_credentials(self):
14+
mock_credentials = Mock()
15+
mock_token = Mock()
16+
mock_credentials.get_token.return_value = mock_token
17+
mock_token.token = "my-aad-token"
18+
return mock_credentials
19+
20+
@patch("importlib.metadata.version")
21+
@patch("os.environ")
22+
def test_set_environment(self, mock_os_environ, mock_get_version, mock_credentials):
23+
mock_get_version.return_value = "0.28.1"
24+
connection = AzureOpenAIConnection(name="my-connection", target="https://test.com", credentials=ApiKeyConfiguration(key="abc"), api_version="2021-01-01")
25+
26+
connection.set_current_environment()
27+
assert mock_os_environ.__setitem__.call_count == 4
28+
assert mock_os_environ.__setitem__.call_args_list[0][0] == ("OPENAI_API_KEY", "abc")
29+
assert mock_os_environ.__setitem__.call_args_list[1][0] == ("OPENAI_API_TYPE", "azure")
30+
assert mock_os_environ.__setitem__.call_args_list[2][0] == ("OPENAI_API_BASE", "https://test.com")
31+
assert mock_os_environ.__setitem__.call_args_list[3][0] == ("OPENAI_API_VERSION", "2021-01-01")
32+
33+
connection.set_current_environment(mock_credentials)
34+
assert mock_os_environ.__setitem__.call_count == 8
35+
assert mock_os_environ.__setitem__.call_args_list[4][0] == ("OPENAI_API_TYPE", "azure_ad")
36+
assert mock_os_environ.__setitem__.call_args_list[5][0] == ("OPENAI_API_KEY", "my-aad-token")
37+
assert mock_os_environ.__setitem__.call_args_list[6][0] == ("OPENAI_API_BASE", "https://test.com")
38+
assert mock_os_environ.__setitem__.call_args_list[7][0] == ("OPENAI_API_VERSION", "2021-01-01")
39+
40+
41+
@patch("importlib.metadata.version")
42+
@patch("os.environ")
43+
def test_set_environment_100(self, mock_os_environ, mock_get_version, mock_credentials):
44+
mock_get_version.return_value = "1.0.0"
45+
connection = AzureOpenAIConnection(name="my-connection", target="https://test.com", credentials=ApiKeyConfiguration(key="abc"), api_version="2021-01-01")
46+
47+
connection.set_current_environment()
48+
# This is a hacky way to verity os.environ, any better way?
49+
assert mock_os_environ.__setitem__.call_count == 3
50+
# call_args[0] is args, call_args[1] is kwargs
51+
assert mock_os_environ.__setitem__.call_args_list[0][0] == ("AZURE_OPENAI_API_KEY", "abc")
52+
assert mock_os_environ.__setitem__.call_args_list[1][0] == ("OPENAI_API_VERSION", "2021-01-01")
53+
assert mock_os_environ.__setitem__.call_args_list[2][0] == ("AZURE_OPENAI_ENDPOINT", "https://test.com")
54+
55+
connection.set_current_environment(mock_credentials)
56+
assert mock_os_environ.__setitem__.call_count == 6
57+
assert mock_os_environ.__setitem__.call_args_list[3][0] == ("AZURE_OPENAI_AD_TOKEN", "my-aad-token")
58+
assert mock_os_environ.__setitem__.call_args_list[4][0] == ("OPENAI_API_VERSION", "2021-01-01")
59+
assert mock_os_environ.__setitem__.call_args_list[5][0] == ("AZURE_OPENAI_ENDPOINT", "https://test.com")

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# Release History
22

3-
## 1.12.0 (unreleased)
3+
## 1.13.0 (unreleased)
4+
5+
### Features Added
6+
7+
### Bugs Fixed
8+
9+
### Breaking Changes
10+
11+
### Other Changes
12+
13+
## 1.12.0 (2023-11-13)
414

515
### Features Added
616
- Workspace Connections had 3 child classes added for open AI, cog search, and cog service connections.

sdk/ml/azure-ai-ml/azure/ai/ml/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
VERSION = "1.12.0"
5+
VERSION = "1.13.0"

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_resource.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def __init__(
5454
self.tags = dict(tags) if tags else {}
5555
self.properties = dict(properties) if properties else {}
5656
# Conditional assignment to prevent entity bloat when unused.
57-
print_as_yaml = kwargs.pop("print_as_yaml", in_jupyter_notebook())
58-
if print_as_yaml:
59-
self.print_as_yaml = True
57+
self._print_as_yaml = kwargs.pop("print_as_yaml", False)
6058

6159
# Hide read only properties in kwargs
6260
self._id = kwargs.pop("id", None)
@@ -194,7 +192,7 @@ def __repr__(self) -> str:
194192
return f"{self.__class__.__name__}({var_dict})"
195193

196194
def __str__(self) -> str:
197-
if hasattr(self, "print_as_yaml") and self.print_as_yaml:
195+
if self._print_as_yaml or in_jupyter_notebook():
198196
# pylint: disable=no-member
199197
yaml_serialized = self._to_dict()
200198
return dump_yaml(yaml_serialized, default_flow_style=False)

0 commit comments

Comments
 (0)