Skip to content

Commit 9634e44

Browse files
committed
Fix the issue that the same MongoClient gets re-used in case we connect to 2 databases on the same host (problematic when different users authenticate)
1 parent 048a045 commit 9634e44

File tree

1 file changed

+42
-15
lines changed

1 file changed

+42
-15
lines changed

mongoengine/connection.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
235235
raise MongoEngineConnectionError(msg)
236236

237237
def _clean_settings(settings_dict):
238-
# set literal more efficient than calling set function
239238
irrelevant_fields_set = {
240239
'name', 'username', 'password',
241240
'authentication_source', 'authentication_mechanism'
@@ -245,10 +244,11 @@ def _clean_settings(settings_dict):
245244
if k not in irrelevant_fields_set
246245
}
247246

247+
raw_conn_settings = _connection_settings[alias].copy()
248248
# Retrieve a copy of the connection settings associated with the requested
249249
# alias and remove the database name and authentication info (we don't
250250
# care about them at this point).
251-
conn_settings = _clean_settings(_connection_settings[alias].copy())
251+
conn_settings = _clean_settings(raw_conn_settings)
252252

253253
# Determine if we should use PyMongo's or mongomock's MongoClient.
254254
is_mock = conn_settings.pop('is_mock', False)
@@ -262,19 +262,8 @@ def _clean_settings(settings_dict):
262262
else:
263263
connection_class = MongoClient
264264

265-
# Iterate over all of the connection settings and if a connection with
266-
# the same parameters is already established, use it instead of creating
267-
# a new one.
268-
existing_connection = None
269-
connection_settings_iterator = (
270-
(db_alias, settings.copy())
271-
for db_alias, settings in _connection_settings.items()
272-
)
273-
for db_alias, connection_settings in connection_settings_iterator:
274-
connection_settings = _clean_settings(connection_settings)
275-
if conn_settings == connection_settings and _connections.get(db_alias):
276-
existing_connection = _connections[db_alias]
277-
break
265+
# Re-use existing connection if one is suitable
266+
existing_connection = _find_existing_connection(raw_conn_settings)
278267

279268
# If an existing connection was found, assign it to the new alias
280269
if existing_connection:
@@ -291,6 +280,44 @@ def _clean_settings(settings_dict):
291280
return _connections[alias]
292281

293282

283+
def _create_connection(connection_class, **connection_settings):
284+
# Otherwise, create the new connection for this alias. Raise
285+
# MongoEngineConnectionError if it can't be established.
286+
try:
287+
_connections[alias] = connection_class(**conn_settings)
288+
except Exception as e:
289+
raise MongoEngineConnectionError(
290+
'Cannot connect to database %s :\n%s' % (alias, e))
291+
292+
293+
def _find_existing_connection(connection_settings):
294+
"""
295+
Check if an existing connection could be reused
296+
297+
Iterate over all of the connection settings and if an existing connection
298+
with the same parameters is suitable, return it
299+
300+
:param connection_settings: the settings of the new connection
301+
:return: An existing connection or None
302+
"""
303+
connection_settings_iterator = (
304+
(db_alias, settings.copy())
305+
for db_alias, settings in _connection_settings.items()
306+
)
307+
308+
def _clean_settings(settings_dict):
309+
# Only remove the name but it's important to
310+
# keep the username/password/authentication_source/authentication_mechanism
311+
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
312+
return {k: v for k, v in settings_dict.items() if k != 'name'}
313+
314+
cleaned_conn_settings = _clean_settings(connection_settings)
315+
for db_alias, connection_settings in connection_settings_iterator:
316+
db_conn_settings = _clean_settings(connection_settings)
317+
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
318+
return _connections[db_alias]
319+
320+
294321
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
295322
if reconnect:
296323
disconnect(alias)

0 commit comments

Comments
 (0)