|
19 | 19 | import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.PAGE_TOTAL_PARAM_NAME;
|
20 | 20 | import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.QUERY_SELECT_PARAM_NAME;
|
21 | 21 | import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.extractPageArgument;
|
| 22 | +import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.findArgument; |
| 23 | +import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getAliasOrName; |
| 24 | +import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getFields; |
22 | 25 | import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getPageArgument;
|
23 | 26 | import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getSelectionField;
|
24 | 27 | import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.searchByFieldName;
|
25 | 28 |
|
| 29 | +import com.introproventures.graphql.jpa.query.schema.JavaScalars; |
| 30 | +import graphql.GraphQLException; |
26 | 31 | import graphql.language.Argument;
|
| 32 | +import graphql.language.EnumValue; |
27 | 33 | import graphql.language.Field;
|
28 | 34 | import graphql.schema.DataFetcher;
|
29 | 35 | import graphql.schema.DataFetchingEnvironment;
|
| 36 | +import graphql.schema.GraphQLScalarType; |
30 | 37 | import java.util.ArrayList;
|
| 38 | +import java.util.Arrays; |
| 39 | +import java.util.LinkedHashMap; |
31 | 40 | import java.util.List;
|
| 41 | +import java.util.Map; |
32 | 42 | import java.util.Optional;
|
| 43 | +import java.util.stream.Stream; |
33 | 44 | import org.slf4j.Logger;
|
34 | 45 | import org.slf4j.LoggerFactory;
|
35 | 46 |
|
@@ -65,6 +76,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
|
65 | 76 | Optional<Field> pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME);
|
66 | 77 | Optional<Field> totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME);
|
67 | 78 | Optional<Field> recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME);
|
| 79 | + Optional<Field> aggregateSelection = getSelectionField(rootNode, "aggregate"); |
68 | 80 |
|
69 | 81 | final int firstResult = page.getOffset();
|
70 | 82 | final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM
|
@@ -98,9 +110,155 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
|
98 | 110 | pagedResult.withTotal(total);
|
99 | 111 | }
|
100 | 112 |
|
| 113 | + aggregateSelection.ifPresent(aggregateField -> { |
| 114 | + Map<String, Object> aggregate = new LinkedHashMap<>(); |
| 115 | + |
| 116 | + getFields(aggregateField.getSelectionSet(), "count") |
| 117 | + .forEach(countField -> { |
| 118 | + getCountOfArgument(countField) |
| 119 | + .ifPresentOrElse( |
| 120 | + argument -> |
| 121 | + aggregate.put( |
| 122 | + getAliasOrName(countField), |
| 123 | + queryFactory.queryAggregateCount(argument, environment, restrictedKeys) |
| 124 | + ), |
| 125 | + () -> |
| 126 | + aggregate.put( |
| 127 | + getAliasOrName(countField), |
| 128 | + queryFactory.queryTotalCount(environment, restrictedKeys) |
| 129 | + ) |
| 130 | + ); |
| 131 | + }); |
| 132 | + |
| 133 | + getFields(aggregateField.getSelectionSet(), "group") |
| 134 | + .forEach(groupField -> { |
| 135 | + var countField = getFields(groupField.getSelectionSet(), "count") |
| 136 | + .stream() |
| 137 | + .findFirst() |
| 138 | + .orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)); |
| 139 | + |
| 140 | + var countOfArgumentValue = getCountOfArgument(countField); |
| 141 | + |
| 142 | + Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by") |
| 143 | + .stream() |
| 144 | + .map(GraphQLJpaQueryDataFetcher::groupByFieldEntry) |
| 145 | + .toArray(Map.Entry[]::new); |
| 146 | + |
| 147 | + if (groupings.length == 0) { |
| 148 | + throw new GraphQLException("At least one field is required for aggregate group: " + groupField); |
| 149 | + } |
| 150 | + |
| 151 | + var resultList = queryFactory |
| 152 | + .queryAggregateGroupByCount( |
| 153 | + getAliasOrName(countField), |
| 154 | + countOfArgumentValue, |
| 155 | + environment, |
| 156 | + restrictedKeys, |
| 157 | + groupings |
| 158 | + ) |
| 159 | + .stream() |
| 160 | + .peek(map -> |
| 161 | + Stream |
| 162 | + .of(groupings) |
| 163 | + .forEach(group -> { |
| 164 | + var value = map.get(group.getKey()); |
| 165 | + |
| 166 | + Optional |
| 167 | + .ofNullable(value) |
| 168 | + .map(Object::getClass) |
| 169 | + .map(JavaScalars::of) |
| 170 | + .map(GraphQLScalarType::getCoercing) |
| 171 | + .ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value))); |
| 172 | + }) |
| 173 | + ) |
| 174 | + .toList(); |
| 175 | + |
| 176 | + aggregate.put(getAliasOrName(groupField), resultList); |
| 177 | + }); |
| 178 | + |
| 179 | + aggregateField |
| 180 | + .getSelectionSet() |
| 181 | + .getSelections() |
| 182 | + .stream() |
| 183 | + .filter(Field.class::isInstance) |
| 184 | + .map(Field.class::cast) |
| 185 | + .filter(it -> !Arrays.asList("count", "group").contains(it.getName())) |
| 186 | + .forEach(groupField -> { |
| 187 | + var countField = getFields(groupField.getSelectionSet(), "count") |
| 188 | + .stream() |
| 189 | + .findFirst() |
| 190 | + .orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)); |
| 191 | + |
| 192 | + Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by") |
| 193 | + .stream() |
| 194 | + .map(GraphQLJpaQueryDataFetcher::groupByFieldEntry) |
| 195 | + .toArray(Map.Entry[]::new); |
| 196 | + |
| 197 | + if (groupings.length == 0) { |
| 198 | + throw new GraphQLException("At least one field is required for aggregate group: " + groupField); |
| 199 | + } |
| 200 | + |
| 201 | + var resultList = queryFactory |
| 202 | + .queryAggregateGroupByAssociationCount( |
| 203 | + getAliasOrName(countField), |
| 204 | + groupField.getName(), |
| 205 | + environment, |
| 206 | + restrictedKeys, |
| 207 | + groupings |
| 208 | + ) |
| 209 | + .stream() |
| 210 | + .peek(map -> |
| 211 | + Stream |
| 212 | + .of(groupings) |
| 213 | + .forEach(group -> { |
| 214 | + var value = map.get(group.getKey()); |
| 215 | + |
| 216 | + Optional |
| 217 | + .ofNullable(value) |
| 218 | + .map(Object::getClass) |
| 219 | + .map(JavaScalars::of) |
| 220 | + .map(GraphQLScalarType::getCoercing) |
| 221 | + .ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value))); |
| 222 | + }) |
| 223 | + ) |
| 224 | + .toList(); |
| 225 | + |
| 226 | + aggregate.put(getAliasOrName(groupField), resultList); |
| 227 | + }); |
| 228 | + |
| 229 | + pagedResult.withAggregate(aggregate); |
| 230 | + }); |
| 231 | + |
101 | 232 | return pagedResult.build();
|
102 | 233 | }
|
103 | 234 |
|
| 235 | + static Map.Entry<String, String> groupByFieldEntry(Field selectedField) { |
| 236 | + String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName()); |
| 237 | + |
| 238 | + String value = findArgument(selectedField, "field") |
| 239 | + .map(Argument::getValue) |
| 240 | + .map(EnumValue.class::cast) |
| 241 | + .map(EnumValue::getName) |
| 242 | + .orElseThrow(() -> new GraphQLException("group by argument is required.")); |
| 243 | + |
| 244 | + return Map.entry(key, value); |
| 245 | + } |
| 246 | + |
| 247 | + static Map.Entry<String, String> countFieldEntry(Field selectedField) { |
| 248 | + String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName()); |
| 249 | + |
| 250 | + String value = getCountOfArgument(selectedField).orElse(selectedField.getName()); |
| 251 | + |
| 252 | + return Map.entry(key, value); |
| 253 | + } |
| 254 | + |
| 255 | + static Optional<String> getCountOfArgument(Field selectedField) { |
| 256 | + return findArgument(selectedField, "of") |
| 257 | + .map(Argument::getValue) |
| 258 | + .map(EnumValue.class::cast) |
| 259 | + .map(EnumValue::getName); |
| 260 | + } |
| 261 | + |
104 | 262 | public int getDefaultMaxResults() {
|
105 | 263 | return defaultMaxResults;
|
106 | 264 | }
|
|
0 commit comments