Skip to content

Commit 56254c2

Browse files
committed
feat: Implement basic query optimizer with predicate pushdown
1 parent f2691f8 commit 56254c2

File tree

6 files changed

+277
-0
lines changed

6 files changed

+277
-0
lines changed

glint/src/main/java/co/clflushopt/glint/query/logical/plan/Scan.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ public List<String> getProjections() {
3232
return projections;
3333
}
3434

35+
public String getPath() {
36+
return path;
37+
}
38+
3539
@Override
3640
public Schema getSchema() {
3741
return schema;
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package co.clflushopt.glint.query.optimizer;
2+
3+
import java.util.HashSet;
4+
import java.util.List;
5+
import java.util.Set;
6+
7+
import co.clflushopt.glint.query.logical.expr.AggregateExpr;
8+
import co.clflushopt.glint.query.logical.expr.AliasExpr;
9+
import co.clflushopt.glint.query.logical.expr.BinaryExpr;
10+
import co.clflushopt.glint.query.logical.expr.CastExpr;
11+
import co.clflushopt.glint.query.logical.expr.ColumnExpr;
12+
import co.clflushopt.glint.query.logical.expr.ColumnIndex;
13+
import co.clflushopt.glint.query.logical.expr.LogicalExpr;
14+
import co.clflushopt.glint.query.logical.plan.LogicalPlan;
15+
16+
/**
17+
* The column extractor extracts nbamed columns from a logical plan.
18+
*
19+
* ColumnExtractor
20+
*/
21+
public class ColumnExtractor {
22+
23+
/**
24+
* Extracts all the named columns from the logical plan.
25+
*
26+
* @param plan
27+
* @param expressions
28+
* @return
29+
*/
30+
public static Set<String> extractColumns(LogicalPlan plan, List<LogicalExpr> expressions) {
31+
Set<String> columns = new HashSet<>();
32+
for (LogicalExpr expression : expressions) {
33+
columns.addAll(extractColumns(plan, expression));
34+
}
35+
return columns;
36+
}
37+
38+
/**
39+
* Extracts the named columns from the logical plan.
40+
*
41+
* @param plan the logical plan.
42+
* @return the named columns.
43+
*/
44+
public static Set<String> extractColumns(LogicalPlan plan, LogicalExpr expression) {
45+
Set<String> columns = new HashSet<>();
46+
if (expression instanceof ColumnExpr) {
47+
columns.add(((ColumnExpr) expression).getName());
48+
}
49+
if (expression instanceof ColumnIndex) {
50+
// Extract the column name using the index and the logical plan schema.
51+
ColumnIndex columnIndex = (ColumnIndex) expression;
52+
columns.add(plan.getSchema().getFields().get(columnIndex.getIndex()).name());
53+
}
54+
if (expression instanceof AggregateExpr) {
55+
columns.addAll(extractColumns(plan, ((AggregateExpr) expression).getExpr()));
56+
}
57+
if (expression instanceof BinaryExpr) {
58+
columns.addAll(extractColumns(plan, ((BinaryExpr) expression).getLhs()));
59+
columns.addAll(extractColumns(plan, ((BinaryExpr) expression).getRhs()));
60+
}
61+
if (expression instanceof AliasExpr) {
62+
columns.addAll(extractColumns(plan, ((AliasExpr) expression).getExpr()));
63+
}
64+
if (expression instanceof CastExpr) {
65+
columns.addAll(extractColumns(plan, ((CastExpr) expression).getExpr()));
66+
}
67+
68+
return columns;
69+
}
70+
71+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package co.clflushopt.glint.query.optimizer;
2+
3+
import co.clflushopt.glint.query.logical.plan.LogicalPlan;
4+
5+
/**
6+
* An optimizer rule is an interface that allows chaining and applying rules to
7+
* a query plan.
8+
*
9+
* OptimizerRule
10+
*/
11+
public interface OptimizerRule {
12+
13+
/**
14+
* Apply the rule to the query plan.
15+
*
16+
* @param plan the query plan to apply the rule to.
17+
* @return the optimized query plan.
18+
*/
19+
public LogicalPlan apply(LogicalPlan plan);
20+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package co.clflushopt.glint.query.optimizer;
2+
3+
import java.util.HashSet;
4+
import java.util.stream.Collectors;
5+
6+
import co.clflushopt.glint.query.logical.plan.Aggregate;
7+
import co.clflushopt.glint.query.logical.plan.LogicalPlan;
8+
import co.clflushopt.glint.query.logical.plan.Projection;
9+
import co.clflushopt.glint.query.logical.plan.Scan;
10+
import co.clflushopt.glint.query.logical.plan.Selection;
11+
12+
public class PredicatePushdownRule implements OptimizerRule {
13+
14+
@Override
15+
public LogicalPlan apply(LogicalPlan plan) {
16+
return pushdown(plan, new HashSet<>());
17+
}
18+
19+
private static LogicalPlan pushdown(LogicalPlan plan, HashSet<String> columns) {
20+
if (plan instanceof Projection) {
21+
var projection = (Projection) plan;
22+
columns.addAll(ColumnExtractor.extractColumns(plan, ((Projection) plan).getExpr()));
23+
var input = pushdown(projection.getInput(), columns);
24+
return new Projection(input, ((Projection) plan).getExpr());
25+
}
26+
if (plan instanceof Selection) {
27+
var selection = (Selection) plan;
28+
var newColumns = new HashSet<>(columns);
29+
newColumns.addAll(ColumnExtractor.extractColumns(plan, selection.getExpr()));
30+
var input = pushdown(selection.getInput(), newColumns);
31+
return new Selection(input, selection.getExpr());
32+
}
33+
if (plan instanceof Aggregate) {
34+
var aggregate = (Aggregate) plan;
35+
var newColumns = new HashSet<>(columns);
36+
newColumns.addAll(ColumnExtractor.extractColumns(plan, aggregate.getGroupExpr()));
37+
newColumns.addAll(ColumnExtractor.extractColumns(plan,
38+
aggregate.getAggregateExpr().stream().map(e -> e.getExpr()).toList()));
39+
var input = pushdown(aggregate.getInput(), newColumns);
40+
return new Aggregate(input, aggregate.getGroupExpr(), aggregate.getAggregateExpr());
41+
}
42+
if (plan instanceof Scan) {
43+
var scanPlan = (Scan) plan;
44+
var fieldNames = ((Scan) plan).getDataSource().getSchema().getFields().stream()
45+
.map(f -> f.name()).collect(Collectors.toSet());
46+
var pushdownColumns = fieldNames.stream().filter(columns::contains)
47+
.collect(Collectors.toList());
48+
return new Scan(scanPlan.getPath(), scanPlan.getDataSource(), pushdownColumns);
49+
}
50+
return plan;
51+
}
52+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package co.clflushopt.glint.query.optimizer;
2+
3+
import co.clflushopt.glint.query.logical.plan.LogicalPlan;
4+
5+
/**
6+
* The query optimizer is responsible for optimizing the query plan at the
7+
* logical level.
8+
*
9+
* QueryOptimizer
10+
*/
11+
public class QueryOptimizer {
12+
13+
/**
14+
* Optimizes the logical plan by applying all rules in the optimizer.
15+
*
16+
* @param plan
17+
* @return
18+
*/
19+
public static LogicalPlan optimize(LogicalPlan plan) {
20+
return plan;
21+
}
22+
23+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package co.clflushopt.glint.query.optimizer;
2+
3+
import static org.junit.Assert.assertEquals;
4+
5+
import java.io.FileNotFoundException;
6+
import java.util.Arrays;
7+
import java.util.Collections;
8+
import java.util.List;
9+
import java.util.Optional;
10+
11+
import org.junit.Test;
12+
13+
import co.clflushopt.glint.dataframe.DataFrame;
14+
import co.clflushopt.glint.dataframe.DataFrameImpl;
15+
import co.clflushopt.glint.datasource.CsvDataSource;
16+
import co.clflushopt.glint.query.logical.expr.AggregateExpr;
17+
import co.clflushopt.glint.query.logical.expr.BooleanExpr;
18+
import co.clflushopt.glint.query.logical.expr.ColumnExpr;
19+
import co.clflushopt.glint.query.logical.expr.LiteralString;
20+
import co.clflushopt.glint.query.logical.expr.LogicalExpr;
21+
import co.clflushopt.glint.query.logical.plan.LogicalPlan;
22+
import co.clflushopt.glint.query.logical.plan.Scan;
23+
import co.clflushopt.glint.types.ArrowTypes;
24+
import co.clflushopt.glint.types.Field;
25+
import co.clflushopt.glint.types.Schema;
26+
27+
public class QueryOptimizerTest {
28+
29+
@Test
30+
public void testProjectionPushDown() throws FileNotFoundException {
31+
DataFrame df = csv().project(Arrays.asList(col("id"), col("first_name"), col("last_name")));
32+
33+
PredicatePushdownRule rule = new PredicatePushdownRule();
34+
LogicalPlan optimizedPlan = rule.apply(df.getLogicalPlan());
35+
36+
String expected = "Projection: #id, #first_name, #last_name\n"
37+
+ "\tScan:employee [projection=(last_name, id, first_name)]\n";
38+
39+
assertEquals(expected, LogicalPlan.format(optimizedPlan));
40+
}
41+
42+
@Test
43+
public void testProjectionPushDownWithSelection() throws FileNotFoundException {
44+
DataFrame df = csv().filter(eq(col("state"), lit("CO")))
45+
.project(Arrays.asList(col("id"), col("first_name"), col("last_name")));
46+
47+
PredicatePushdownRule rule = new PredicatePushdownRule();
48+
LogicalPlan optimizedPlan = rule.apply(df.getLogicalPlan());
49+
50+
String expected = "Projection: #id, #first_name, #last_name\n" + "\tFilter: #state = 'CO'\n"
51+
+ "\t\tScan:employee [projection=(last_name, id, state, first_name)]\n";
52+
53+
assertEquals(expected, LogicalPlan.format(optimizedPlan));
54+
}
55+
56+
@Test
57+
public void testProjectionPushDownWithAggregateQuery() throws FileNotFoundException {
58+
DataFrame df = csv().aggregate(Collections.singletonList(col("state")),
59+
List.of(min(col("salary")), max(col("salary")), count(col("salary"))));
60+
61+
PredicatePushdownRule rule = new PredicatePushdownRule();
62+
LogicalPlan optimizedPlan = rule.apply(df.getLogicalPlan());
63+
64+
String expected = "Aggregate: groupExpr=[#state], aggregateExpr=[MIN(#salary), MAX(#salary), COUNT(#salary)]\n"
65+
+ "\tScan:employee [projection=(state, salary)]\n";
66+
67+
assertEquals(expected, LogicalPlan.format(optimizedPlan));
68+
}
69+
70+
private DataFrame csv() throws FileNotFoundException {
71+
String employeeCsv = "../testdata/employee.csv";
72+
Schema schema = new Schema(Arrays.asList(new Field("id", ArrowTypes.Int64Type),
73+
new Field("first_name", ArrowTypes.StringType),
74+
new Field("last_name", ArrowTypes.StringType),
75+
new Field("state", ArrowTypes.StringType),
76+
new Field("job_title", ArrowTypes.StringType),
77+
new Field("salary", ArrowTypes.Int64Type)));
78+
return new DataFrameImpl(new Scan("employee",
79+
new CsvDataSource(employeeCsv, Optional.of(schema), true, 1024),
80+
Collections.emptyList()));
81+
}
82+
83+
// Helper methods for creating expressions
84+
private LogicalExpr col(String name) {
85+
return new ColumnExpr(name);
86+
}
87+
88+
private LogicalExpr lit(String value) {
89+
return new LiteralString(value);
90+
}
91+
92+
private LogicalExpr eq(LogicalExpr left, LogicalExpr right) {
93+
return BooleanExpr.Eq(left, right);
94+
}
95+
96+
private AggregateExpr min(LogicalExpr expr) {
97+
return new AggregateExpr.Min(expr);
98+
}
99+
100+
private AggregateExpr max(LogicalExpr expr) {
101+
return new AggregateExpr.Max(expr);
102+
}
103+
104+
private AggregateExpr count(LogicalExpr expr) {
105+
return new AggregateExpr.Count(expr);
106+
}
107+
}

0 commit comments

Comments
 (0)