Skip to content

Commit bb72bbb

Browse files
arunsureshkumararun-sureshkumar
authored andcommitted
Graphene Federation v2 Support Added
2 parents 889eda8 + 4a44e2a commit bb72bbb

File tree

6 files changed

+198
-80
lines changed

6 files changed

+198
-80
lines changed

graphene_mongo/advanced_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ class FileFieldType(graphene.ObjectType):
99
length = graphene.Int()
1010
data = graphene.String()
1111

12+
# Support Graphene Federation v2
13+
_shareable = True
14+
1215
@classmethod
1316
def _resolve_fs_field(cls, field, name, default_value=None):
1417
v = getattr(field.instance, field.key)
@@ -37,6 +40,9 @@ def resolve_data(self, info):
3740
class _CoordinatesTypeField(graphene.ObjectType):
3841
type = graphene.String()
3942

43+
# Support Graphene Federation v2
44+
_shareable = True
45+
4046
def resolve_type(self, info):
4147
return self["type"]
4248

graphene_mongo/fields.py

Lines changed: 143 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33
from collections import OrderedDict
44
from functools import partial, reduce
55

6+
import bson
67
import graphene
78
import mongoengine
89
from bson import DBRef, ObjectId
910
from graphene import Context
10-
from graphene.types.utils import get_type
11-
from graphene.utils.str_converters import to_snake_case
12-
from graphql import GraphQLResolveInfo
13-
from mongoengine.base import get_document
14-
from promise import Promise
15-
from graphql_relay import from_global_id
1611
from graphene.relay import ConnectionField
1712
from graphene.types.argument import to_arguments
1813
from graphene.types.dynamic import Dynamic
1914
from graphene.types.structures import Structure
20-
from graphql_relay.connection.array_connection import cursor_to_offset
15+
from graphene.types.utils import get_type
16+
from graphene.utils.str_converters import to_snake_case
17+
from graphql import GraphQLResolveInfo
18+
from graphql_relay import from_global_id
19+
from graphql_relay.connection.arrayconnection import cursor_to_offset
2120
from mongoengine import QuerySet
21+
from mongoengine.base import get_document
22+
from promise import Promise
23+
from pymongo.errors import OperationFailure
2224

2325
from .advanced_types import (
2426
FileFieldType,
@@ -30,6 +32,9 @@
3032
from .registry import get_global_registry
3133
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
3234
connection_from_iterables
35+
import pymongo
36+
37+
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
3338

3439

3540
class MongoengineConnectionField(ConnectionField):
@@ -77,9 +82,27 @@ def registry(self):
7782

7883
@property
7984
def args(self):
85+
_field_args = self.field_args
86+
_advance_args = self.advance_args
87+
_filter_args = self.filter_args
88+
_extended_args = self.extended_args
89+
if self._type._meta.non_filter_fields:
90+
for _field in self._type._meta.non_filter_fields:
91+
if _field in _field_args:
92+
_field_args.pop(_field)
93+
if _field in _advance_args:
94+
_advance_args.pop(_field)
95+
if _field in _filter_args:
96+
_filter_args.pop(_field)
97+
if _field in _extended_args:
98+
_filter_args.pop(_field)
99+
extra_args = dict(dict(dict(_field_args, **_advance_args), **_filter_args), **_extended_args)
100+
101+
for key in list(self._base_args.keys()):
102+
extra_args.pop(key, None)
80103
return to_arguments(
81104
self._base_args or OrderedDict(),
82-
dict(dict(dict(self.field_args, **self.advance_args), **self.filter_args), **self.extended_args),
105+
extra_args
83106
)
84107

85108
@args.setter
@@ -100,6 +123,14 @@ def is_filterable(k):
100123
return False
101124
if not hasattr(self.model, k):
102125
return False
126+
else:
127+
# else section is a patch for federated field error
128+
field_ = self.fields[k]
129+
type_ = field_.type
130+
while hasattr(type_, "of_type"):
131+
type_ = type_.of_type
132+
if hasattr(type_, "_sdl") and "@key" in type_._sdl:
133+
return False
103134
if isinstance(getattr(self.model, k), property):
104135
return False
105136
try:
@@ -128,6 +159,9 @@ def is_filterable(k):
128159
getattr(converted, "_of_type", None), graphene.Union
129160
):
130161
return False
162+
# below if condition: workaround for DB filterable field redefined as custom graphene type
163+
if hasattr(field_, 'type') and hasattr(converted, 'type') and converted.type != field_.type:
164+
return False
131165
return True
132166

133167
def get_filter_type(_type):
@@ -150,7 +184,7 @@ def filter_args(self):
150184
if self._type._meta.filter_fields:
151185
for field, filter_collection in self._type._meta.filter_fields.items():
152186
for each in filter_collection:
153-
if str(self._type._meta.fields[field].type) == 'PointFieldType':
187+
if str(self._type._meta.fields[field].type) in ('PointFieldType', 'PointFieldType!'):
154188
if each == 'max_distance':
155189
filter_type = graphene.Int
156190
else:
@@ -279,17 +313,17 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
279313
skip)
280314
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
281315

282-
def default_resolver(self, _root, info, required_fields=None, **args):
316+
def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
283317
if required_fields is None:
284318
required_fields = list()
285319
args = args or {}
286320
for key, value in dict(args).items():
287321
if value is None:
288322
del args[key]
289-
if _root is not None:
323+
if _root is not None and not resolved:
290324
field_name = to_snake_case(info.field_name)
291325
if not hasattr(_root, "_fields_ordered"):
292-
if getattr(_root, field_name, []) is not None:
326+
if isinstance(getattr(_root, field_name, []), list):
293327
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
294328
elif field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field,
295329
mongoengine.EmbeddedDocumentField) or
@@ -316,25 +350,33 @@ def default_resolver(self, _root, info, required_fields=None, **args):
316350
before = args.pop("before", None)
317351
if before:
318352
before = cursor_to_offset(before)
319-
if callable(getattr(self.model, "objects", None)):
320-
if "pk__in" in args and args["pk__in"]:
321-
count = len(args["pk__in"])
322-
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
323-
count=count)
324-
if limit:
325-
if reverse:
326-
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
327-
else:
328-
args["pk__in"] = args["pk__in"][skip:skip + limit]
329-
elif skip:
330-
args["pk__in"] = args["pk__in"][skip:]
331-
iterables = self.get_queryset(self.model, info, required_fields, **args)
332-
list_length = len(iterables)
333-
if isinstance(info, GraphQLResolveInfo):
334-
if not info.context:
335-
info = info._replace(context=Context())
336-
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
337-
elif _root is None or args:
353+
354+
if resolved is not None:
355+
items = resolved
356+
357+
if isinstance(items, QuerySet):
358+
try:
359+
count = items.count(with_limit_and_skip=True)
360+
except OperationFailure:
361+
count = len(items)
362+
else:
363+
count = len(items)
364+
365+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
366+
count=count)
367+
368+
if limit:
369+
if reverse:
370+
items = items[::-1][skip:skip + limit]
371+
else:
372+
items = items[skip:skip + limit]
373+
elif skip:
374+
items = items[skip:]
375+
iterables = items
376+
list_length = len(iterables)
377+
378+
elif callable(getattr(self.model, "objects", None)):
379+
if _root is None or args or isinstance(getattr(_root, field_name, []), MongoengineConnectionField):
338380
args_copy = args.copy()
339381
for key in args.copy():
340382
if key not in self.model._fields_ordered:
@@ -346,8 +388,20 @@ def default_resolver(self, _root, info, required_fields=None, **args):
346388
mongoengine.fields.LazyReferenceField) or isinstance(getattr(self.model, key),
347389
mongoengine.fields.CachedReferenceField):
348390
if not isinstance(args_copy[key], ObjectId):
349-
args_copy[key] = from_global_id(args_copy[key])[1]
350-
count = mongoengine.get_db()[self.model._get_collection_name()].count_documents(args_copy)
391+
_from_global_id = from_global_id(args_copy[key])[1]
392+
if bson.objectid.ObjectId.is_valid(_from_global_id):
393+
args_copy[key] = ObjectId(_from_global_id)
394+
else:
395+
args_copy[key] = _from_global_id
396+
elif isinstance(getattr(self.model, key),
397+
mongoengine.fields.EnumField):
398+
if getattr(args_copy[key], "value", None):
399+
args_copy[key] = args_copy[key].value
400+
401+
if PYMONGO_VERSION >= (3, 7):
402+
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
403+
else:
404+
count = self.model.objects(args_copy).count()
351405
if count != 0:
352406
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
353407
count=count)
@@ -358,6 +412,24 @@ def default_resolver(self, _root, info, required_fields=None, **args):
358412
info = info._replace(context=Context())
359413
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
360414

415+
elif "pk__in" in args and args["pk__in"]:
416+
count = len(args["pk__in"])
417+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
418+
count=count)
419+
if limit:
420+
if reverse:
421+
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
422+
else:
423+
args["pk__in"] = args["pk__in"][skip:skip + limit]
424+
elif skip:
425+
args["pk__in"] = args["pk__in"][skip:]
426+
iterables = self.get_queryset(self.model, info, required_fields, **args)
427+
list_length = len(iterables)
428+
if isinstance(info, GraphQLResolveInfo):
429+
if not info.context:
430+
info = info._replace(context=Context())
431+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
432+
361433
elif _root is not None:
362434
field_name = to_snake_case(info.field_name)
363435
items = getattr(_root, field_name, [])
@@ -373,6 +445,7 @@ def default_resolver(self, _root, info, required_fields=None, **args):
373445
items = items[skip:]
374446
iterables = items
375447
list_length = len(iterables)
448+
376449
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
377450
has_previous_page = True if skip else False
378451
if reverse:
@@ -391,31 +464,42 @@ def default_resolver(self, _root, info, required_fields=None, **args):
391464
return connection
392465

393466
def chained_resolver(self, resolver, is_partial, root, info, **args):
467+
394468
for key, value in dict(args).items():
395469
if value is None:
396470
del args[key]
471+
397472
required_fields = list()
473+
398474
for field in self.required_fields:
399475
if field in self.model._fields_ordered:
400476
required_fields.append(field)
477+
401478
for field in get_query_fields(info):
402479
if to_snake_case(field) in self.model._fields_ordered:
403480
required_fields.append(to_snake_case(field))
481+
404482
args_copy = args.copy()
483+
405484
if not bool(args) or not is_partial:
406485
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
407486
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
408487

488+
from itertools import filterfalse
489+
connection_fields = [field for field in self.fields if
490+
type(self.fields[field]) == MongoengineConnectionField]
491+
filterable_args = tuple(filterfalse(connection_fields.__contains__, list(self.model._fields_ordered)))
409492
for arg_name, arg in args.copy().items():
410-
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
493+
if arg_name not in filterable_args + tuple(self.filter_args.keys()):
411494
args_copy.pop(arg_name)
412495
if isinstance(info, GraphQLResolveInfo):
413496
if not info.context:
414497
info = info._replace(context=Context())
415-
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
498+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)
416499

417500
# XXX: Filter nested args
418501
resolved = resolver(root, info, **args)
502+
419503
if resolved is not None:
420504
if isinstance(resolved, list):
421505
if resolved == list():
@@ -428,36 +512,55 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
428512
args.update(resolved._query)
429513
args_copy = args.copy()
430514
for arg_name, arg in args.copy().items():
431-
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
432-
self.filter_args.keys()):
515+
if "." in arg_name or arg_name not in self.model._fields_ordered + (
516+
'first', 'last', 'before', 'after') + tuple(
517+
self.filter_args.keys()):
433518
args_copy.pop(arg_name)
434519
if arg_name == '_id' and isinstance(arg, dict):
435520
operation = list(arg.keys())[0]
436521
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
437522
if not isinstance(arg, ObjectId) and '.' in arg_name:
438-
operation = list(arg.keys())[0]
439-
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
523+
if type(arg) == dict:
524+
operation = list(arg.keys())[0]
525+
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[
526+
operation]
527+
else:
528+
args_copy[arg_name.replace('.', '__')] = arg
529+
elif '.' in arg_name and isinstance(arg, ObjectId):
530+
args_copy[arg_name.replace('.', '__')] = arg
440531
else:
441532
operations = ["$lte", "$gte", "$ne", "$in"]
442533
if isinstance(arg, dict) and any(op in arg for op in operations):
443534
operation = list(arg.keys())[0]
444535
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
445536
del args_copy[arg_name]
446-
return self.default_resolver(root, info, required_fields, **args_copy)
537+
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
447538
elif isinstance(resolved, Promise):
448539
return resolved.value
449540
else:
450541
return resolved
542+
451543
return self.default_resolver(root, info, required_fields, **args)
452544

453545
@classmethod
454546
def connection_resolver(cls, resolver, connection_type, root, info, **args):
547+
if root:
548+
for key, value in root.__dict__.items():
549+
if value:
550+
try:
551+
setattr(root, key, from_global_id(value)[1])
552+
except Exception as error:
553+
pass
455554
iterable = resolver(root, info, **args)
555+
456556
if isinstance(connection_type, graphene.NonNull):
457557
connection_type = connection_type.of_type
558+
458559
on_resolve = partial(cls.resolve_connection, connection_type, args)
560+
459561
if Promise.is_thenable(iterable):
460562
return Promise.resolve(iterable).then(on_resolve)
563+
461564
return on_resolve(iterable)
462565

463566
def get_resolver(self, parent_resolver):

graphene_mongo/registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ def register_enum(self, cls):
2929
assert type(cls) == EnumMeta, 'Only EnumMeta can be registered, received "{}"'.format(
3030
cls.__name__
3131
)
32+
if not cls.__name__.endswith('Enum'):
33+
name = cls.__name__ + 'Enum'
34+
else:
35+
name = cls.__name__
36+
cls.__name__ = name
3237
self._registry_enum[cls] = Enum.from_enum(cls)
3338

3439
def get_type_for_model(self, model):

0 commit comments

Comments
 (0)