Skip to content

Commit a7c0b1c

Browse files
committed
Accept uuidrepr in URI
1 parent ac8ba50 commit a7c0b1c

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

mongoengine/connection.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import warnings
22

33
from pymongo import MongoClient, ReadPreference, uri_parser
4+
from pymongo.common import (
5+
_UUID_REPRESENTATIONS,
6+
_CaseInsensitiveDictionary,
7+
)
48
from pymongo.database import _check_name
59

610
from mongoengine.pymongo_support import PYMONGO_VERSION
@@ -124,7 +128,7 @@ def _get_connection_settings(
124128
if uri_dict.get(param):
125129
conn_settings[param] = uri_dict[param]
126130

127-
uri_options = uri_dict["options"]
131+
uri_options: _CaseInsensitiveDictionary = uri_dict["options"]
128132
if "replicaset" in uri_options:
129133
conn_settings["replicaSet"] = uri_options["replicaset"]
130134
if "authsource" in uri_options:
@@ -159,6 +163,13 @@ def _get_connection_settings(
159163
conn_settings["authmechanismproperties"] = uri_options[
160164
"authmechanismproperties"
161165
]
166+
if "uuidrepresentation" in uri_options:
167+
REV_UUID_REPRESENTATIONS = {
168+
v: k for k, v in _UUID_REPRESENTATIONS.items()
169+
}
170+
conn_settings["uuidrepresentation"] = REV_UUID_REPRESENTATIONS[
171+
uri_options["uuidrepresentation"]
172+
]
162173
else:
163174
resolved_hosts.append(entity)
164175
conn_settings["host"] = resolved_hosts
@@ -170,7 +181,7 @@ def _get_connection_settings(
170181
keys = {
171182
key.lower() for key in kwargs.keys()
172183
} # pymongo options are case insensitive
173-
if "uuidrepresentation" not in keys:
184+
if "uuidrepresentation" not in keys and "uuidrepresentation" not in conn_settings:
174185
warnings.warn(
175186
"No uuidRepresentation is specified! Falling back to "
176187
"'pythonLegacy' which is the default for pymongo 3.x. "

tests/test_connection.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import unittest
3+
import uuid
34

45
import pymongo
56
import pytest
@@ -30,6 +31,10 @@
3031
from mongoengine.pymongo_support import PYMONGO_VERSION
3132

3233

34+
def random_str():
35+
return str(uuid.uuid4())
36+
37+
3338
def get_tz_awareness(connection):
3439
return connection.codec_options.tz_aware
3540

@@ -624,6 +629,50 @@ def test_connect_2_databases_uses_different_client_if_different_parameters(self)
624629
c2 = connect(alias="testdb2", db="testdb2", username="u2", password="pass")
625630
assert c1 is not c2
626631

632+
def test_connect_uri_uuidrepresentation_set_in_uri(self):
633+
rand = random_str()
634+
tmp_conn = connect(
635+
alias=rand,
636+
host=f"mongodb://localhost:27017/{rand}?uuidRepresentation=csharpLegacy",
637+
)
638+
assert (
639+
tmp_conn.options.codec_options.uuid_representation
640+
== pymongo.common.UuidRepresentation.CSHARP_LEGACY
641+
)
642+
disconnect(rand)
643+
644+
def test_connect_uri_uuidrepresentation_set_as_arg(self):
645+
rand = random_str()
646+
tmp_conn = connect(alias=rand, db=rand, uuidRepresentation="javaLegacy")
647+
assert (
648+
tmp_conn.options.codec_options.uuid_representation
649+
== pymongo.common.UuidRepresentation.JAVA_LEGACY
650+
)
651+
disconnect(rand)
652+
653+
def test_connect_uri_uuidrepresentation_set_both_arg_and_uri_arg_prevail(self):
654+
rand = random_str()
655+
tmp_conn = connect(
656+
alias=rand,
657+
host=f"mongodb://localhost:27017/{rand}?uuidRepresentation=csharpLegacy",
658+
uuidRepresentation="javaLegacy",
659+
)
660+
assert (
661+
tmp_conn.options.codec_options.uuid_representation
662+
== pymongo.common.UuidRepresentation.JAVA_LEGACY
663+
)
664+
disconnect(rand)
665+
666+
def test_connect_uri_uuidrepresentation_default_to_pythonlegacy(self):
667+
# To be changed soon to unspecified
668+
rand = random_str()
669+
tmp_conn = connect(alias=rand, db=rand)
670+
assert (
671+
tmp_conn.options.codec_options.uuid_representation
672+
== pymongo.common.UuidRepresentation.PYTHON_LEGACY
673+
)
674+
disconnect(rand)
675+
627676

628677
if __name__ == "__main__":
629678
unittest.main()

0 commit comments

Comments
 (0)