Skip to content

Commit 07ca75f

Browse files
authored
Merge pull request #28 from graphql-python/feat-filter-by-reference
Feat filter by reference, fixed #25
2 parents f2ce892 + d89f823 commit 07ca75f

File tree

5 files changed

+104
-9
lines changed

5 files changed

+104
-9
lines changed

graphene_mongo/fields.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from graphene.types.argument import to_arguments
1212

1313

14+
from .utils import get_model_reference_fields
15+
16+
1417
# noqa
1518
class MongoengineListField(Field):
1619

@@ -60,26 +63,35 @@ def model(self):
6063
@property
6164
def args(self):
6265
return to_arguments(
63-
self._base_args or OrderedDict(), self.default_filter_args
66+
self._base_args or OrderedDict(),
67+
dict(self.field_args, **self.reference_args)
6468
)
6569

6670
@args.setter
6771
def args(self, args):
6872
self._base_args = args
6973

7074
@property
71-
def default_filter_args(self):
75+
def field_args(self):
7276
def is_filterable(kv):
7377
return hasattr(kv[1], '_type') \
7478
and callable(getattr(kv[1]._type, '_of_type', None))
7579

7680
return reduce(
7781
lambda r, kv: r.update(
7882
{kv[0]: kv[1]._type._of_type()}) or r if is_filterable(kv) else r,
79-
self.fields.items(),
80-
{}
83+
self.fields.items(), {}
8184
)
8285

86+
@property
87+
def reference_args(self):
88+
def get_reference_field(r, kv):
89+
if callable(getattr(kv[1], 'get_type', None)):
90+
node = kv[1].get_type()._type._meta
91+
r.update({kv[0]: node.fields['id']._type.of_type()})
92+
return r
93+
return reduce(get_reference_field, self.fields.items(), {})
94+
8395
@property
8496
def filter_fields(self):
8597
return self._type._meta.filter_fields
@@ -95,8 +107,17 @@ def get_query(cls, model, info, **args):
95107
return []
96108

97109
objs = model.objects()
98-
99110
if args:
111+
reference_fields = get_model_reference_fields(model)
112+
reference_args = {}
113+
for arg_name, arg in args.copy().items():
114+
if arg_name in reference_fields:
115+
reference_model = model._fields[arg_name]
116+
pk = from_global_id(args.pop(arg_name))[-1]
117+
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
118+
reference_args[arg_name] = reference_obj
119+
120+
args.update(reference_args)
100121
first = args.pop('first', None)
101122
last = args.pop('last', None)
102123
id = args.pop('id', None)
@@ -121,7 +142,7 @@ def get_query(cls, model, info, **args):
121142
if first is not None:
122143
objs = objs[:first]
123144
if last is not None:
124-
# fix for https://github.com/graphql-python/graphene-mongo/issues/20
145+
# https://github.com/graphql-python/graphene-mongo/issues/20
125146
objs = objs[-(last+1):]
126147

127148
return objs

graphene_mongo/tests/test_fields.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
from ..fields import MongoengineConnectionField
4+
from .types import ArticleNode
5+
6+
7+
def test_field_args():
8+
field = MongoengineConnectionField(ArticleNode)
9+
10+
field_args = ['id', 'headline', 'pub_date']
11+
assert set(field.field_args.keys()) == set(field_args)
12+
13+
reference_args = ['editor', 'reporter']
14+
assert set(field.reference_args.keys()) == set(reference_args)
15+
16+
default_args = ['after', 'last', 'first', 'before']
17+
args = field_args + reference_args + default_args
18+
assert set(field.args) == set(args)

graphene_mongo/tests/test_relay_query.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,46 @@ class Query(graphene.ObjectType):
295295
assert result.data == expected
296296

297297

298+
def test_should_filter_by_reference_field():
299+
300+
class Query(graphene.ObjectType):
301+
node = Node.Field()
302+
articles = MongoengineConnectionField(ArticleNode)
303+
304+
query = '''
305+
query ArticlesQuery {
306+
articles(editor: "RWRpdG9yTm9kZTox") {
307+
edges {
308+
node {
309+
headline,
310+
editor {
311+
firstName
312+
}
313+
}
314+
}
315+
}
316+
}
317+
'''
318+
expected = {
319+
'articles': {
320+
'edges': [
321+
{
322+
'node': {
323+
'headline': 'Hello',
324+
'editor': {
325+
'firstName': 'Penny'
326+
}
327+
}
328+
}
329+
]
330+
}
331+
}
332+
schema = graphene.Schema(query=Query)
333+
result = schema.execute(query)
334+
assert not result.errors
335+
assert result.data == expected
336+
337+
298338
def test_should_filter_through_inheritance():
299339

300340
class Query(graphene.ObjectType):

graphene_mongo/tests/test_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,20 @@
55
from ..utils import (
66
get_model_fields, is_valid_mongoengine_model
77
)
8-
from .models import Reporter, Child
8+
from .models import Article, Reporter, Child
9+
910

1011
def test_get_model_fields_no_duplication():
1112
reporter_fields = get_model_fields(Reporter)
1213
reporter_name_set = set(reporter_fields)
1314
assert len(reporter_fields) == len(reporter_name_set)
1415

1516

17+
def test_get_model_relation_fields():
18+
article_fields = get_model_fields(Article)
19+
assert all(field in set(article_fields) for field in ['editor', 'reporter'])
20+
21+
1622
def test_get_base_model_fields():
1723
child_fields = get_model_fields(Child)
1824
assert all(field in set(child_fields) for field in ['bar', 'baz'])

graphene_mongo/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66

77
def get_model_fields(model, excluding=None):
8-
if excluding is None:
9-
excluding = []
8+
excluding = excluding or []
109
attributes = dict()
1110
for attr_name, attr in model._fields.items():
1211
if attr_name in excluding:
@@ -15,6 +14,17 @@ def get_model_fields(model, excluding=None):
1514
return OrderedDict(sorted(attributes.items()))
1615

1716

17+
def get_model_reference_fields(model, excluding=None):
18+
excluding = excluding or []
19+
attributes = dict()
20+
for attr_name, attr in model._fields.items():
21+
if attr_name in excluding \
22+
or not isinstance(attr, mongoengine.fields.ReferenceField):
23+
continue
24+
attributes[attr_name] = attr
25+
return attributes
26+
27+
1828
def is_valid_mongoengine_model(model):
1929
return inspect.isclass(model) and (
2030
issubclass(model, mongoengine.Document) or issubclass(model, mongoengine.EmbeddedDocument)

0 commit comments

Comments
 (0)