Skip to content

Commit 0ac59c6

Browse files
authored
Merge pull request #2068 from bagerard/fix_connection_auth_same_host
Fix connection issue when using different authentication in different dbs
2 parents 8e8c74c + 36aebff commit 0ac59c6

File tree

3 files changed

+59
-22
lines changed

3 files changed

+59
-22
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Development
1313
- expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all`
1414
- Fix disconnect function #566 #1599 #605 #607 #1213 #565
1515
- Improve connect/disconnect documentations
16+
- Fix issue when using multiple connections to the same mongo with different credentials #2047
1617
- POTENTIAL BREAKING CHANGES: (associated with connect/disconnect fixes)
1718
- calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first)
1819
- disconnect now clears `mongoengine.connection._connection_settings`

mongoengine/connection.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
235235
raise MongoEngineConnectionError(msg)
236236

237237
def _clean_settings(settings_dict):
238-
# set literal more efficient than calling set function
239238
irrelevant_fields_set = {
240239
'name', 'username', 'password',
241240
'authentication_source', 'authentication_mechanism'
@@ -245,10 +244,12 @@ def _clean_settings(settings_dict):
245244
if k not in irrelevant_fields_set
246245
}
247246

247+
raw_conn_settings = _connection_settings[alias].copy()
248+
248249
# Retrieve a copy of the connection settings associated with the requested
249250
# alias and remove the database name and authentication info (we don't
250251
# care about them at this point).
251-
conn_settings = _clean_settings(_connection_settings[alias].copy())
252+
conn_settings = _clean_settings(raw_conn_settings)
252253

253254
# Determine if we should use PyMongo's or mongomock's MongoClient.
254255
is_mock = conn_settings.pop('is_mock', False)
@@ -262,35 +263,60 @@ def _clean_settings(settings_dict):
262263
else:
263264
connection_class = MongoClient
264265

265-
# Iterate over all of the connection settings and if a connection with
266-
# the same parameters is already established, use it instead of creating
267-
# a new one.
268-
existing_connection = None
269-
connection_settings_iterator = (
270-
(db_alias, settings.copy())
271-
for db_alias, settings in _connection_settings.items()
272-
)
273-
for db_alias, connection_settings in connection_settings_iterator:
274-
connection_settings = _clean_settings(connection_settings)
275-
if conn_settings == connection_settings and _connections.get(db_alias):
276-
existing_connection = _connections[db_alias]
277-
break
266+
# Re-use existing connection if one is suitable
267+
existing_connection = _find_existing_connection(raw_conn_settings)
278268

279269
# If an existing connection was found, assign it to the new alias
280270
if existing_connection:
281271
_connections[alias] = existing_connection
282272
else:
283-
# Otherwise, create the new connection for this alias. Raise
284-
# MongoEngineConnectionError if it can't be established.
285-
try:
286-
_connections[alias] = connection_class(**conn_settings)
287-
except Exception as e:
288-
raise MongoEngineConnectionError(
289-
'Cannot connect to database %s :\n%s' % (alias, e))
273+
_connections[alias] = _create_connection(alias=alias,
274+
connection_class=connection_class,
275+
**conn_settings)
290276

291277
return _connections[alias]
292278

293279

280+
def _create_connection(alias, connection_class, **connection_settings):
281+
"""
282+
Create the new connection for this alias. Raise
283+
MongoEngineConnectionError if it can't be established.
284+
"""
285+
try:
286+
return connection_class(**connection_settings)
287+
except Exception as e:
288+
raise MongoEngineConnectionError(
289+
'Cannot connect to database %s :\n%s' % (alias, e))
290+
291+
292+
def _find_existing_connection(connection_settings):
293+
"""
294+
Check if an existing connection could be reused
295+
296+
Iterate over all of the connection settings and if an existing connection
297+
with the same parameters is suitable, return it
298+
299+
:param connection_settings: the settings of the new connection
300+
:return: An existing connection or None
301+
"""
302+
connection_settings_bis = (
303+
(db_alias, settings.copy())
304+
for db_alias, settings in _connection_settings.items()
305+
)
306+
307+
def _clean_settings(settings_dict):
308+
# Only remove the name but it's important to
309+
# keep the username/password/authentication_source/authentication_mechanism
310+
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
311+
return {k: v for k, v in settings_dict.items() if k != 'name'}
312+
313+
cleaned_conn_settings = _clean_settings(connection_settings)
314+
for db_alias, connection_settings in connection_settings_bis:
315+
db_conn_settings = _clean_settings(connection_settings)
316+
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
317+
return _connections[db_alias]
318+
319+
294320
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
295321
if reconnect:
296322
disconnect(alias)

tests/test_connection.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,16 @@ def test_multiple_connection_settings(self):
611611
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
612612
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
613613

614+
def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self):
615+
c1 = connect(alias='testdb1', db='testdb1')
616+
c2 = connect(alias='testdb2', db='testdb2')
617+
self.assertIs(c1, c2)
618+
619+
def test_connect_2_databases_uses_different_client_if_different_parameters(self):
620+
c1 = connect(alias='testdb1', db='testdb1', username='u1')
621+
c2 = connect(alias='testdb2', db='testdb2', username='u2')
622+
self.assertIsNot(c1, c2)
623+
614624

615625
if __name__ == '__main__':
616626
unittest.main()

0 commit comments

Comments
 (0)