Skip to content

Commit 04f2d69

Browse files
authored
improve confluence doc loader param validation (langchain-ai#9568)
1 parent 0fea987 commit 04f2d69

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

libs/langchain/langchain/document_loaders/confluence.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,15 @@ def __init__(
118118
):
119119
confluence_kwargs = confluence_kwargs or {}
120120
errors = ConfluenceLoader.validate_init_args(
121-
url, api_key, username, oauth2, token
121+
url=url,
122+
api_key=api_key,
123+
username=username,
124+
session=session,
125+
oauth2=oauth2,
126+
token=token,
122127
)
123128
if errors:
124129
raise ValueError(f"Error(s) while validating input: {errors}")
125-
126-
self.base_url = url
127-
self.number_of_retries = number_of_retries
128-
self.min_retry_seconds = min_retry_seconds
129-
self.max_retry_seconds = max_retry_seconds
130-
131130
try:
132131
from atlassian import Confluence # noqa: F401
133132
except ImportError:
@@ -136,6 +135,11 @@ def __init__(
136135
"`pip install atlassian-python-api`"
137136
)
138137

138+
self.base_url = url
139+
self.number_of_retries = number_of_retries
140+
self.min_retry_seconds = min_retry_seconds
141+
self.max_retry_seconds = max_retry_seconds
142+
139143
if session:
140144
self.confluence = Confluence(url=url, session=session, **confluence_kwargs)
141145
elif oauth2:
@@ -160,6 +164,7 @@ def validate_init_args(
160164
url: Optional[str] = None,
161165
api_key: Optional[str] = None,
162166
username: Optional[str] = None,
167+
session: Optional[requests.Session] = None,
163168
oauth2: Optional[dict] = None,
164169
token: Optional[str] = None,
165170
) -> Union[List, None]:
@@ -175,33 +180,28 @@ def validate_init_args(
175180
"the other must be as well."
176181
)
177182

178-
if (api_key or username) and oauth2:
183+
non_null_creds = list(
184+
x is not None for x in ((api_key or username), session, oauth2, token)
185+
)
186+
if sum(non_null_creds) > 1:
187+
all_names = ("(api_key, username)", "session", "oath2", "token")
188+
provided = tuple(n for x, n in zip(non_null_creds, all_names) if x)
179189
errors.append(
180-
"Cannot provide a value for `api_key` and/or "
181-
"`username` and provide a value for `oauth2`"
190+
f"Cannot provide a value for more than one of: {all_names}. Received "
191+
f"values for: {provided}"
182192
)
183-
184-
if oauth2 and oauth2.keys() != [
193+
if oauth2 and set(oauth2.keys()) != {
185194
"access_token",
186195
"access_token_secret",
187196
"consumer_key",
188197
"key_cert",
189-
]:
198+
}:
190199
errors.append(
191200
"You have either omitted require keys or added extra "
192201
"keys to the oauth2 dictionary. key values should be "
193202
"`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`"
194203
)
195-
196-
if token and (api_key or username or oauth2):
197-
errors.append(
198-
"Cannot provide a value for `token` and a value for `api_key`, "
199-
"`username` or `oauth2`"
200-
)
201-
202-
if errors:
203-
return errors
204-
return None
204+
return errors or None
205205

206206
def load(
207207
self,

libs/langchain/tests/unit_tests/document_loaders/test_confluence.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import MagicMock, patch
44

55
import pytest
6+
import requests
67

78
from langchain.docstore.document import Document
89
from langchain.document_loaders.confluence import ConfluenceLoader
@@ -23,7 +24,7 @@ class TestConfluenceLoader:
2324

2425
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
2526
ConfluenceLoader(
26-
url=self.CONFLUENCE_URL,
27+
self.CONFLUENCE_URL,
2728
username=self.MOCK_USERNAME,
2829
api_key=self.MOCK_API_TOKEN,
2930
)
@@ -34,6 +35,36 @@ def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> N
3435
cloud=True,
3536
)
3637

38+
def test_confluence_loader_initialization_invalid(self) -> None:
39+
with pytest.raises(ValueError):
40+
ConfluenceLoader(
41+
self.CONFLUENCE_URL,
42+
username=self.MOCK_USERNAME,
43+
api_key=self.MOCK_API_TOKEN,
44+
token="foo",
45+
)
46+
47+
with pytest.raises(ValueError):
48+
ConfluenceLoader(
49+
self.CONFLUENCE_URL,
50+
username=self.MOCK_USERNAME,
51+
api_key=self.MOCK_API_TOKEN,
52+
oauth2={
53+
"access_token": "bar",
54+
"access_token_secret": "bar",
55+
"consumer_key": "bar",
56+
"key_cert": "bar",
57+
},
58+
)
59+
60+
with pytest.raises(ValueError):
61+
ConfluenceLoader(
62+
self.CONFLUENCE_URL,
63+
username=self.MOCK_USERNAME,
64+
api_key=self.MOCK_API_TOKEN,
65+
session=requests.Session(),
66+
)
67+
3768
def test_confluence_loader_initialization_from_env(
3869
self, mock_confluence: MagicMock
3970
) -> None:
@@ -51,7 +82,7 @@ def test_confluence_loader_initialization_from_env(
5182

5283
def test_confluence_loader_load_data_invalid_args(self) -> None:
5384
confluence_loader = ConfluenceLoader(
54-
url=self.CONFLUENCE_URL,
85+
self.CONFLUENCE_URL,
5586
username=self.MOCK_USERNAME,
5687
api_key=self.MOCK_API_TOKEN,
5788
)
@@ -125,7 +156,7 @@ def _get_mock_confluence_loader(
125156
self, mock_confluence: MagicMock
126157
) -> ConfluenceLoader:
127158
confluence_loader = ConfluenceLoader(
128-
url=self.CONFLUENCE_URL,
159+
self.CONFLUENCE_URL,
129160
username=self.MOCK_USERNAME,
130161
api_key=self.MOCK_API_TOKEN,
131162
)

0 commit comments

Comments
 (0)