3434import graphql .schema .DataFetcher ;
3535import graphql .schema .DataFetchingEnvironment ;
3636import graphql .schema .GraphQLScalarType ;
37- import java .util .ArrayList ;
37+ import java .util .Collection ;
3838import java .util .LinkedHashMap ;
3939import java .util .List ;
4040import java .util .Map ;
5353class GraphQLJpaQueryDataFetcher implements DataFetcher <PagedResult <Object >> {
5454
5555 private static final Logger logger = LoggerFactory .getLogger (GraphQLJpaQueryDataFetcher .class );
56+ public static final String AGGREGATE_PARAM_NAME = "aggregate" ;
57+ public static final String COUNT_FIELD_NAME = "count" ;
58+ public static final String GROUP_FIELD_NAME = "group" ;
59+ public static final String BY_FILED_NAME = "by" ;
60+ public static final String FIELD_ARGUMENT_NAME = "field" ;
61+ public static final String OF_ARGUMENT_NAME = "of" ;
5662
5763 private final int defaultMaxResults ;
5864 private final int defaultPageLimitSize ;
@@ -76,7 +82,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
7682 Optional <Field > pagesSelection = getSelectionField (rootNode , PAGE_PAGES_PARAM_NAME );
7783 Optional <Field > totalSelection = getSelectionField (rootNode , PAGE_TOTAL_PARAM_NAME );
7884 Optional <Field > recordsSelection = searchByFieldName (rootNode , QUERY_SELECT_PARAM_NAME );
79- Optional <Field > aggregateSelection = getSelectionField (rootNode , "aggregate" );
85+ Optional <Field > aggregateSelection = getSelectionField (rootNode , AGGREGATE_PARAM_NAME );
8086
8187 final int firstResult = page .getOffset ();
8288 final int maxResults = Integer .min (page .getLimit (), defaultMaxResults ); // Limit max results to avoid OoM
@@ -85,35 +91,47 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
8591 .builder ()
8692 .withOffset (firstResult )
8793 .withLimit (maxResults );
88- Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
94+
95+ final Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
8996
9097 if (recordsSelection .isPresent ()) {
9198 if (restrictedKeys .isPresent ()) {
92- final List <Object > queryKeys = new ArrayList <>();
93-
9499 if (pageArgument .isPresent () || enableDefaultMaxResults ) {
95- queryKeys .addAll (
96- queryFactory .queryKeys (environment , firstResult , maxResults , restrictedKeys .get ())
100+ final List <Object > queryKeys = queryFactory .queryKeys (
101+ environment ,
102+ firstResult ,
103+ maxResults ,
104+ restrictedKeys .get ()
97105 );
106+
107+ if (!queryKeys .isEmpty ()) {
108+ pagedResult .withSelect (
109+ queryFactory .queryResultList (environment , maxResults , restrictedKeys .get ())
110+ );
111+ } else {
112+ pagedResult .withSelect (List .of ());
113+ }
98114 } else {
99- queryKeys . addAll ( restrictedKeys .get ());
115+ pagedResult . withSelect ( queryFactory . queryResultList ( environment , maxResults , restrictedKeys .get () ));
100116 }
101-
102- final List <Object > resultList = queryFactory .queryResultList (environment , maxResults , queryKeys );
103- pagedResult .withSelect (resultList );
104117 }
105118 }
106119
107120 if (totalSelection .isPresent () || pagesSelection .isPresent ()) {
108- final Long total = queryFactory .queryTotalCount (environment , restrictedKeys );
121+ final var selectResult = pagedResult .getSelect ();
122+
123+ final long total = recordsSelection .isEmpty () ||
124+ selectResult .filter (Predicate .not (Collection ::isEmpty )).isPresent ()
125+ ? queryFactory .queryTotalCount (environment , restrictedKeys )
126+ : 0L ;
109127
110128 pagedResult .withTotal (total );
111129 }
112130
113131 aggregateSelection .ifPresent (aggregateField -> {
114132 Map <String , Object > aggregate = new LinkedHashMap <>();
115133
116- getFields (aggregateField .getSelectionSet (), "count" )
134+ getFields (aggregateField .getSelectionSet (), COUNT_FIELD_NAME )
117135 .forEach (countField -> {
118136 getCountOfArgument (countField )
119137 .ifPresentOrElse (
@@ -130,16 +148,16 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
130148 );
131149 });
132150
133- getFields (aggregateField .getSelectionSet (), "group" )
151+ getFields (aggregateField .getSelectionSet (), GROUP_FIELD_NAME )
134152 .forEach (groupField -> {
135- var countField = getFields (groupField .getSelectionSet (), "count" )
153+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
136154 .stream ()
137155 .findFirst ()
138156 .orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField ));
139157
140158 var countOfArgumentValue = getCountOfArgument (countField );
141159
142- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
160+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
143161 .stream ()
144162 .map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
145163 .toArray (Map .Entry []::new );
@@ -176,21 +194,21 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
176194 aggregate .put (getAliasOrName (groupField ), resultList );
177195 });
178196
179- getSelectionField (aggregateField , "by" )
197+ getSelectionField (aggregateField , BY_FILED_NAME )
180198 .map (byField -> byField .getSelectionSet ().getSelections ().stream ().map (Field .class ::cast ).toList ())
181199 .filter (Predicate .not (List ::isEmpty ))
182200 .ifPresent (aggregateBySelections -> {
183201 var aggregatesBy = new LinkedHashMap <>();
184- aggregate .put ("by" , aggregatesBy );
202+ aggregate .put (BY_FILED_NAME , aggregatesBy );
185203
186204 aggregateBySelections .forEach (groupField -> {
187- var countField = getFields (groupField .getSelectionSet (), "count" )
205+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
188206 .stream ()
189207 .findFirst ()
190208 .orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField )
191209 );
192210
193- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
211+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
194212 .stream ()
195213 .map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
196214 .toArray (Map .Entry []::new );
@@ -239,7 +257,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
239257 static Map .Entry <String , String > groupByFieldEntry (Field selectedField ) {
240258 String key = Optional .ofNullable (selectedField .getAlias ()).orElse (selectedField .getName ());
241259
242- String value = findArgument (selectedField , "field" )
260+ String value = findArgument (selectedField , FIELD_ARGUMENT_NAME )
243261 .map (Argument ::getValue )
244262 .map (EnumValue .class ::cast )
245263 .map (EnumValue ::getName )
@@ -257,7 +275,7 @@ static Map.Entry<String, String> countFieldEntry(Field selectedField) {
257275 }
258276
259277 static Optional <String > getCountOfArgument (Field selectedField ) {
260- return findArgument (selectedField , "of" )
278+ return findArgument (selectedField , OF_ARGUMENT_NAME )
261279 .map (Argument ::getValue )
262280 .map (EnumValue .class ::cast )
263281 .map (EnumValue ::getName );
0 commit comments