3
3
from collections import OrderedDict
4
4
from functools import partial , reduce
5
5
6
+ import bson
6
7
import graphene
7
8
import mongoengine
8
9
from bson import DBRef , ObjectId
9
10
from graphene import Context
10
- from graphene .types .utils import get_type
11
- from graphene .utils .str_converters import to_snake_case
12
- from graphql import GraphQLResolveInfo
13
- from mongoengine .base import get_document
14
- from promise import Promise
15
- from graphql_relay import from_global_id
16
11
from graphene .relay import ConnectionField
17
12
from graphene .types .argument import to_arguments
18
13
from graphene .types .dynamic import Dynamic
19
14
from graphene .types .structures import Structure
20
- from graphql_relay .connection .array_connection import cursor_to_offset
15
+ from graphene .types .utils import get_type
16
+ from graphene .utils .str_converters import to_snake_case
17
+ from graphql import GraphQLResolveInfo
18
+ from graphql_relay import from_global_id
19
+ from graphql_relay .connection .arrayconnection import cursor_to_offset
21
20
from mongoengine import QuerySet
21
+ from mongoengine .base import get_document
22
+ from promise import Promise
23
+ from pymongo .errors import OperationFailure
22
24
23
25
from .advanced_types import (
24
26
FileFieldType ,
30
32
from .registry import get_global_registry
31
33
from .utils import get_model_reference_fields , get_query_fields , find_skip_and_limit , \
32
34
connection_from_iterables
35
+ import pymongo
36
+
37
+ PYMONGO_VERSION = tuple (pymongo .version_tuple [:2 ])
33
38
34
39
35
40
class MongoengineConnectionField (ConnectionField ):
@@ -77,9 +82,27 @@ def registry(self):
77
82
78
83
@property
79
84
def args (self ):
85
+ _field_args = self .field_args
86
+ _advance_args = self .advance_args
87
+ _filter_args = self .filter_args
88
+ _extended_args = self .extended_args
89
+ if self ._type ._meta .non_filter_fields :
90
+ for _field in self ._type ._meta .non_filter_fields :
91
+ if _field in _field_args :
92
+ _field_args .pop (_field )
93
+ if _field in _advance_args :
94
+ _advance_args .pop (_field )
95
+ if _field in _filter_args :
96
+ _filter_args .pop (_field )
97
+ if _field in _extended_args :
98
+ _filter_args .pop (_field )
99
+ extra_args = dict (dict (dict (_field_args , ** _advance_args ), ** _filter_args ), ** _extended_args )
100
+
101
+ for key in list (self ._base_args .keys ()):
102
+ extra_args .pop (key , None )
80
103
return to_arguments (
81
104
self ._base_args or OrderedDict (),
82
- dict ( dict ( dict ( self . field_args , ** self . advance_args ), ** self . filter_args ), ** self . extended_args ),
105
+ extra_args
83
106
)
84
107
85
108
@args .setter
@@ -100,6 +123,14 @@ def is_filterable(k):
100
123
return False
101
124
if not hasattr (self .model , k ):
102
125
return False
126
+ else :
127
+ # else section is a patch for federated field error
128
+ field_ = self .fields [k ]
129
+ type_ = field_ .type
130
+ while hasattr (type_ , "of_type" ):
131
+ type_ = type_ .of_type
132
+ if hasattr (type_ , "_sdl" ) and "@key" in type_ ._sdl :
133
+ return False
103
134
if isinstance (getattr (self .model , k ), property ):
104
135
return False
105
136
try :
@@ -128,6 +159,9 @@ def is_filterable(k):
128
159
getattr (converted , "_of_type" , None ), graphene .Union
129
160
):
130
161
return False
162
+ # below if condition: workaround for DB filterable field redefined as custom graphene type
163
+ if hasattr (field_ , 'type' ) and hasattr (converted , 'type' ) and converted .type != field_ .type :
164
+ return False
131
165
return True
132
166
133
167
def get_filter_type (_type ):
@@ -150,7 +184,7 @@ def filter_args(self):
150
184
if self ._type ._meta .filter_fields :
151
185
for field , filter_collection in self ._type ._meta .filter_fields .items ():
152
186
for each in filter_collection :
153
- if str (self ._type ._meta .fields [field ].type ) == 'PointFieldType' :
187
+ if str (self ._type ._meta .fields [field ].type ) in ( 'PointFieldType' , 'PointFieldType!' ) :
154
188
if each == 'max_distance' :
155
189
filter_type = graphene .Int
156
190
else :
@@ -279,17 +313,17 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
279
313
skip )
280
314
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
281
315
282
- def default_resolver (self , _root , info , required_fields = None , ** args ):
316
+ def default_resolver (self , _root , info , required_fields = None , resolved = None , ** args ):
283
317
if required_fields is None :
284
318
required_fields = list ()
285
319
args = args or {}
286
320
for key , value in dict (args ).items ():
287
321
if value is None :
288
322
del args [key ]
289
- if _root is not None :
323
+ if _root is not None and not resolved :
290
324
field_name = to_snake_case (info .field_name )
291
325
if not hasattr (_root , "_fields_ordered" ):
292
- if getattr (_root , field_name , []) is not None :
326
+ if isinstance ( getattr (_root , field_name , []), list ) :
293
327
args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
294
328
elif field_name in _root ._fields_ordered and not (isinstance (_root ._fields [field_name ].field ,
295
329
mongoengine .EmbeddedDocumentField ) or
@@ -316,25 +350,33 @@ def default_resolver(self, _root, info, required_fields=None, **args):
316
350
before = args .pop ("before" , None )
317
351
if before :
318
352
before = cursor_to_offset (before )
319
- if callable (getattr (self .model , "objects" , None )):
320
- if "pk__in" in args and args ["pk__in" ]:
321
- count = len (args ["pk__in" ])
322
- skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
323
- count = count )
324
- if limit :
325
- if reverse :
326
- args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
327
- else :
328
- args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
329
- elif skip :
330
- args ["pk__in" ] = args ["pk__in" ][skip :]
331
- iterables = self .get_queryset (self .model , info , required_fields , ** args )
332
- list_length = len (iterables )
333
- if isinstance (info , GraphQLResolveInfo ):
334
- if not info .context :
335
- info = info ._replace (context = Context ())
336
- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
337
- elif _root is None or args :
353
+
354
+ if resolved is not None :
355
+ items = resolved
356
+
357
+ if isinstance (items , QuerySet ):
358
+ try :
359
+ count = items .count (with_limit_and_skip = True )
360
+ except OperationFailure :
361
+ count = len (items )
362
+ else :
363
+ count = len (items )
364
+
365
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
366
+ count = count )
367
+
368
+ if limit :
369
+ if reverse :
370
+ items = items [::- 1 ][skip :skip + limit ]
371
+ else :
372
+ items = items [skip :skip + limit ]
373
+ elif skip :
374
+ items = items [skip :]
375
+ iterables = items
376
+ list_length = len (iterables )
377
+
378
+ elif callable (getattr (self .model , "objects" , None )):
379
+ if _root is None or args or isinstance (getattr (_root , field_name , []), MongoengineConnectionField ):
338
380
args_copy = args .copy ()
339
381
for key in args .copy ():
340
382
if key not in self .model ._fields_ordered :
@@ -346,8 +388,20 @@ def default_resolver(self, _root, info, required_fields=None, **args):
346
388
mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
347
389
mongoengine .fields .CachedReferenceField ):
348
390
if not isinstance (args_copy [key ], ObjectId ):
349
- args_copy [key ] = from_global_id (args_copy [key ])[1 ]
350
- count = mongoengine .get_db ()[self .model ._get_collection_name ()].count_documents (args_copy )
391
+ _from_global_id = from_global_id (args_copy [key ])[1 ]
392
+ if bson .objectid .ObjectId .is_valid (_from_global_id ):
393
+ args_copy [key ] = ObjectId (_from_global_id )
394
+ else :
395
+ args_copy [key ] = _from_global_id
396
+ elif isinstance (getattr (self .model , key ),
397
+ mongoengine .fields .EnumField ):
398
+ if getattr (args_copy [key ], "value" , None ):
399
+ args_copy [key ] = args_copy [key ].value
400
+
401
+ if PYMONGO_VERSION >= (3 , 7 ):
402
+ count = (mongoengine .get_db ()[self .model ._get_collection_name ()]).count_documents (args_copy )
403
+ else :
404
+ count = self .model .objects (args_copy ).count ()
351
405
if count != 0 :
352
406
skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
353
407
count = count )
@@ -358,6 +412,24 @@ def default_resolver(self, _root, info, required_fields=None, **args):
358
412
info = info ._replace (context = Context ())
359
413
info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
360
414
415
+ elif "pk__in" in args and args ["pk__in" ]:
416
+ count = len (args ["pk__in" ])
417
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
418
+ count = count )
419
+ if limit :
420
+ if reverse :
421
+ args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
422
+ else :
423
+ args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
424
+ elif skip :
425
+ args ["pk__in" ] = args ["pk__in" ][skip :]
426
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
427
+ list_length = len (iterables )
428
+ if isinstance (info , GraphQLResolveInfo ):
429
+ if not info .context :
430
+ info = info ._replace (context = Context ())
431
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
432
+
361
433
elif _root is not None :
362
434
field_name = to_snake_case (info .field_name )
363
435
items = getattr (_root , field_name , [])
@@ -373,6 +445,7 @@ def default_resolver(self, _root, info, required_fields=None, **args):
373
445
items = items [skip :]
374
446
iterables = items
375
447
list_length = len (iterables )
448
+
376
449
has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
377
450
has_previous_page = True if skip else False
378
451
if reverse :
@@ -391,31 +464,42 @@ def default_resolver(self, _root, info, required_fields=None, **args):
391
464
return connection
392
465
393
466
def chained_resolver (self , resolver , is_partial , root , info , ** args ):
467
+
394
468
for key , value in dict (args ).items ():
395
469
if value is None :
396
470
del args [key ]
471
+
397
472
required_fields = list ()
473
+
398
474
for field in self .required_fields :
399
475
if field in self .model ._fields_ordered :
400
476
required_fields .append (field )
477
+
401
478
for field in get_query_fields (info ):
402
479
if to_snake_case (field ) in self .model ._fields_ordered :
403
480
required_fields .append (to_snake_case (field ))
481
+
404
482
args_copy = args .copy ()
483
+
405
484
if not bool (args ) or not is_partial :
406
485
if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
407
486
mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
408
487
488
+ from itertools import filterfalse
489
+ connection_fields = [field for field in self .fields if
490
+ type (self .fields [field ]) == MongoengineConnectionField ]
491
+ filterable_args = tuple (filterfalse (connection_fields .__contains__ , list (self .model ._fields_ordered )))
409
492
for arg_name , arg in args .copy ().items ():
410
- if arg_name not in self . model . _fields_ordered + tuple (self .filter_args .keys ()):
493
+ if arg_name not in filterable_args + tuple (self .filter_args .keys ()):
411
494
args_copy .pop (arg_name )
412
495
if isinstance (info , GraphQLResolveInfo ):
413
496
if not info .context :
414
497
info = info ._replace (context = Context ())
415
- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
498
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args_copy )
416
499
417
500
# XXX: Filter nested args
418
501
resolved = resolver (root , info , ** args )
502
+
419
503
if resolved is not None :
420
504
if isinstance (resolved , list ):
421
505
if resolved == list ():
@@ -428,36 +512,55 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
428
512
args .update (resolved ._query )
429
513
args_copy = args .copy ()
430
514
for arg_name , arg in args .copy ().items ():
431
- if arg_name not in self .model ._fields_ordered + ('first' , 'last' , 'before' , 'after' ) + tuple (
432
- self .filter_args .keys ()):
515
+ if "." in arg_name or arg_name not in self .model ._fields_ordered + (
516
+ 'first' , 'last' , 'before' , 'after' ) + tuple (
517
+ self .filter_args .keys ()):
433
518
args_copy .pop (arg_name )
434
519
if arg_name == '_id' and isinstance (arg , dict ):
435
520
operation = list (arg .keys ())[0 ]
436
521
args_copy ['pk' + operation .replace ('$' , '__' )] = arg [operation ]
437
522
if not isinstance (arg , ObjectId ) and '.' in arg_name :
438
- operation = list (arg .keys ())[0 ]
439
- args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [operation ]
523
+ if type (arg ) == dict :
524
+ operation = list (arg .keys ())[0 ]
525
+ args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [
526
+ operation ]
527
+ else :
528
+ args_copy [arg_name .replace ('.' , '__' )] = arg
529
+ elif '.' in arg_name and isinstance (arg , ObjectId ):
530
+ args_copy [arg_name .replace ('.' , '__' )] = arg
440
531
else :
441
532
operations = ["$lte" , "$gte" , "$ne" , "$in" ]
442
533
if isinstance (arg , dict ) and any (op in arg for op in operations ):
443
534
operation = list (arg .keys ())[0 ]
444
535
args_copy [arg_name + operation .replace ('$' , '__' )] = arg [operation ]
445
536
del args_copy [arg_name ]
446
- return self .default_resolver (root , info , required_fields , ** args_copy )
537
+ return self .default_resolver (root , info , required_fields , resolved = resolved , ** args_copy )
447
538
elif isinstance (resolved , Promise ):
448
539
return resolved .value
449
540
else :
450
541
return resolved
542
+
451
543
return self .default_resolver (root , info , required_fields , ** args )
452
544
453
545
@classmethod
454
546
def connection_resolver (cls , resolver , connection_type , root , info , ** args ):
547
+ if root :
548
+ for key , value in root .__dict__ .items ():
549
+ if value :
550
+ try :
551
+ setattr (root , key , from_global_id (value )[1 ])
552
+ except Exception as error :
553
+ pass
455
554
iterable = resolver (root , info , ** args )
555
+
456
556
if isinstance (connection_type , graphene .NonNull ):
457
557
connection_type = connection_type .of_type
558
+
458
559
on_resolve = partial (cls .resolve_connection , connection_type , args )
560
+
459
561
if Promise .is_thenable (iterable ):
460
562
return Promise .resolve (iterable ).then (on_resolve )
563
+
461
564
return on_resolve (iterable )
462
565
463
566
def get_resolver (self , parent_resolver ):
0 commit comments