Skip to content

Commit f913cfa

Browse files
committed
test: Add test_should_filter_by_reference_field
1 parent 5d978e8 commit f913cfa

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

graphene_mongo/fields.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from graphene.types.argument import to_arguments
1212

1313

14+
from .utils import get_model_reference_fields
15+
16+
1417
# noqa
1518
class MongoengineListField(Field):
1619

@@ -60,7 +63,8 @@ def model(self):
6063
@property
6164
def args(self):
6265
return to_arguments(
63-
self._base_args or OrderedDict(), dict(self.field_args.items() + self.reference_args.items())
66+
self._base_args or OrderedDict(),
67+
dict(self.field_args.items() + self.reference_args.items())
6468
)
6569

6670
@args.setter
@@ -103,8 +107,15 @@ def get_query(cls, model, info, **args):
103107
return []
104108

105109
objs = model.objects()
106-
107110
if args:
111+
reference_fields = get_model_reference_fields(model)
112+
for arg_name, arg in args.items():
113+
if arg_name in reference_fields:
114+
reference_model = model._fields[arg_name]
115+
pk = from_global_id(args.pop(arg_name))[-1]
116+
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
117+
args[arg_name] = reference_obj
118+
108119
first = args.pop('first', None)
109120
last = args.pop('last', None)
110121
id = args.pop('id', None)
@@ -132,6 +143,7 @@ def get_query(cls, model, info, **args):
132143
# https://github.com/graphql-python/graphene-mongo/issues/20
133144
objs = objs[-(last+1):]
134145

146+
print(objs)
135147
return objs
136148

137149
# noqa

graphene_mongo/tests/test_relay_query.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,46 @@ class Query(graphene.ObjectType):
295295
assert result.data == expected
296296

297297

298+
def test_should_filter_by_reference_field():
299+
300+
class Query(graphene.ObjectType):
301+
node = Node.Field()
302+
articles = MongoengineConnectionField(ArticleNode)
303+
304+
query = '''
305+
query ArticlesQuery {
306+
articles(editor: "RWRpdG9yTm9kZTox") {
307+
edges {
308+
node {
309+
headline,
310+
editor {
311+
firstName
312+
}
313+
}
314+
}
315+
}
316+
}
317+
'''
318+
expected = {
319+
'articles': {
320+
'edges': [
321+
{
322+
'node': {
323+
'headline': 'Hello',
324+
'editor': {
325+
'firstName': 'Penny'
326+
}
327+
}
328+
}
329+
]
330+
}
331+
}
332+
schema = graphene.Schema(query=Query)
333+
result = schema.execute(query)
334+
assert not result.errors
335+
assert result.data == expected
336+
337+
298338
def test_should_filter_through_inheritance():
299339

300340
class Query(graphene.ObjectType):

graphene_mongo/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66

77
def get_model_fields(model, excluding=None):
8-
if excluding is None:
9-
excluding = []
8+
excluding = excluding or []
109
attributes = dict()
1110
for attr_name, attr in model._fields.items():
1211
if attr_name in excluding:
@@ -15,6 +14,17 @@ def get_model_fields(model, excluding=None):
1514
return OrderedDict(sorted(attributes.items()))
1615

1716

17+
def get_model_reference_fields(model, excluding=None):
18+
excluding = excluding or []
19+
attributes = dict()
20+
for attr_name, attr in model._fields.items():
21+
if attr_name in excluding \
22+
or not isinstance(attr, mongoengine.fields.ReferenceField):
23+
continue
24+
attributes[attr_name] = attr
25+
return attributes
26+
27+
1828
def is_valid_mongoengine_model(model):
1929
return inspect.isclass(model) and (
2030
issubclass(model, mongoengine.Document) or issubclass(model, mongoengine.EmbeddedDocument)

0 commit comments

Comments
 (0)