Skip to content

Commit 933f4b4

Browse files
authored
Enable subquery in From (#560)
1 parent 7a613de commit 933f4b4

File tree

3 files changed

+105
-44
lines changed

3 files changed

+105
-44
lines changed

partiql-logical-planner/src/lower.rs

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use std::borrow::Cow;
2323

2424
use partiql_value::BindingsName;
2525

26-
use std::collections::{HashMap, HashSet};
26+
use std::collections::HashMap;
2727

2828
use crate::builtins::{FnSymTab, FN_SYM_TAB};
2929
use itertools::Itertools;
@@ -57,12 +57,21 @@ macro_rules! eq_or_fault {
5757

5858
#[macro_export]
5959
macro_rules! true_or_fault {
60+
($self:ident, $expr:expr, $msg:expr) => {
61+
if !$expr {
62+
true_or_fault_err!($self, $expr, $msg);
63+
return partiql_ast::visit::Traverse::Stop;
64+
}
65+
};
66+
}
67+
68+
#[macro_export]
69+
macro_rules! true_or_fault_err {
6070
($self:ident, $expr:expr, $msg:expr) => {
6171
if !$expr {
6272
$self
6373
.errors
6474
.push(AstTransformError::IllegalState($msg.to_string()));
65-
return partiql_ast::visit::Traverse::Stop;
6675
}
6776
};
6877
}
@@ -161,9 +170,6 @@ pub struct AstToLogical<'a> {
161170
path_stack: Vec<Vec<PathComponent>>,
162171
sort_stack: Vec<Vec<logical::SortSpec>>,
163172
aggregate_exprs: Vec<AggregateExpression>,
164-
165-
from_lets: HashSet<NodeId>,
166-
167173
projection_renames: Vec<FnvIndexMap<String, BindingsName<'a>>>,
168174

169175
aliases: FnvIndexMap<NodeId, SymbolPrimitive>,
@@ -173,7 +179,7 @@ pub struct AstToLogical<'a> {
173179
agg_id: IdGenerator,
174180

175181
// output
176-
plan: LogicalPlan<BindingsOp>,
182+
plan_stack: Vec<LogicalPlan<BindingsOp>>,
177183

178184
// catalog & data flow data
179185
key_registry: name_resolver::KeyRegistry,
@@ -233,8 +239,6 @@ impl<'a> AstToLogical<'a> {
233239
sort_stack: Default::default(),
234240
aggregate_exprs: Default::default(),
235241

236-
from_lets: Default::default(),
237-
238242
projection_renames: Default::default(),
239243

240244
aliases: Default::default(),
@@ -244,7 +248,7 @@ impl<'a> AstToLogical<'a> {
244248
agg_id: Default::default(),
245249

246250
// output
247-
plan: Default::default(),
251+
plan_stack: Default::default(),
248252

249253
key_registry: registry,
250254
fnsym_tab,
@@ -258,13 +262,19 @@ impl<'a> AstToLogical<'a> {
258262
mut self,
259263
query: &ast::AstNode<ast::TopLevelQuery>,
260264
) -> Result<logical::LogicalPlan<logical::BindingsOp>, AstTransformationError> {
265+
self.enter_plan();
261266
query.visit(&mut self);
267+
true_or_fault_err!(
268+
self,
269+
self.plan_stack.len() == 1,
270+
"self.plan_stack.len() != 1"
271+
);
262272
if !self.errors.is_empty() {
263273
return Err(AstTransformationError {
264274
errors: self.errors,
265275
});
266276
}
267-
Ok(self.plan)
277+
Ok(self.plan_stack.pop().unwrap())
268278
}
269279

270280
#[inline]
@@ -509,6 +519,21 @@ impl<'a> AstToLogical<'a> {
509519
self.q_stack.last_mut().unwrap()
510520
}
511521

522+
#[inline]
523+
fn enter_plan(&mut self) {
524+
self.plan_stack.push(LogicalPlan::default());
525+
}
526+
527+
#[inline]
528+
fn exit_plan(&mut self) -> LogicalPlan<BindingsOp> {
529+
self.plan_stack.pop().expect("environment level")
530+
}
531+
532+
#[inline]
533+
fn curr_plan(&mut self) -> &mut LogicalPlan<BindingsOp> {
534+
self.plan_stack.last_mut().expect("plan")
535+
}
536+
512537
#[inline]
513538
fn enter_benv(&mut self) {
514539
self.bexpr_stack.push(vec![]);
@@ -691,8 +716,8 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
691716
let mut benv = self.exit_benv();
692717
eq_or_fault!(self, benv.len(), 1, "Expect benv.len() == 1");
693718
let out = benv.pop().unwrap();
694-
let sink_id = self.plan.add_operator(BindingsOp::Sink);
695-
self.plan.add_flow(out, sink_id);
719+
let sink_id = self.curr_plan().add_operator(BindingsOp::Sink);
720+
self.curr_plan().add_flow(out, sink_id);
696721
Traverse::Continue
697722
}
698723

@@ -712,7 +737,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
712737
let mut clauses = clauses.evaluation_order().into_iter();
713738
if let Some(mut src_id) = clauses.next() {
714739
for dst_id in clauses {
715-
self.plan.add_flow(src_id, dst_id);
740+
self.curr_plan().add_flow(src_id, dst_id);
716741
src_id = dst_id;
717742
}
718743
self.push_bexpr(src_id);
@@ -726,7 +751,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
726751
);
727752
let mut out = *benv.first().unwrap();
728753
benv.into_iter().skip(1).for_each(|op| {
729-
self.plan.add_flow(out, op);
754+
self.curr_plan().add_flow(out, op);
730755
out = op;
731756
});
732757
self.push_bexpr(out);
@@ -759,7 +784,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
759784

760785
match query_set {
761786
QuerySet::BagOp(bag_op) => {
762-
eq_or_fault!(self, benv.len(), 2, "benv.len() != 2");
787+
eq_or_fault!(self, benv.len(), 2, "qs benv.len() != 2");
763788
let rid = benv.pop().unwrap();
764789
let lid = benv.pop().unwrap();
765790

@@ -777,20 +802,20 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
777802
None => logical::SetQuantifier::Distinct,
778803
};
779804

780-
let id = self.plan.add_operator(BindingsOp::BagOp(BagOp {
805+
let id = self.curr_plan().add_operator(BindingsOp::BagOp(BagOp {
781806
bag_op: bag_operator,
782807
setq,
783808
}));
784-
self.plan.add_flow_with_branch_num(lid, id, 0);
785-
self.plan.add_flow_with_branch_num(rid, id, 1);
809+
self.curr_plan().add_flow_with_branch_num(lid, id, 0);
810+
self.curr_plan().add_flow_with_branch_num(rid, id, 1);
786811
self.push_bexpr(id);
787812
}
788813
QuerySet::Select(_) => {}
789814
QuerySet::Expr(_) => {
790815
eq_or_fault!(self, env.len(), 1, "env.len() != 1");
791816
let expr = env.into_iter().next().unwrap();
792817
let op = BindingsOp::ExprQuery(logical::ExprQuery { expr });
793-
let id = self.plan.add_operator(op);
818+
let id = self.curr_plan().add_operator(op);
794819
self.push_bexpr(id);
795820
}
796821
QuerySet::Values(_) => {
@@ -839,7 +864,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
839864
aggregate_exprs: self.aggregate_exprs.clone(),
840865
group_as_alias: None,
841866
});
842-
let id = self.plan.add_operator(group_by);
867+
let id = self.curr_plan().add_operator(group_by);
843868
self.current_clauses_mut().group_by_clause.replace(id);
844869
}
845870
Traverse::Continue
@@ -859,7 +884,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
859884
eq_or_fault!(self, env.len(), 0, "env.len() != 0");
860885

861886
if let Some(SetQuantifier::Distinct) = projection.setq {
862-
let id = self.plan.add_operator(BindingsOp::Distinct);
887+
let id = self.curr_plan().add_operator(BindingsOp::Distinct);
863888
self.current_clauses_mut().distinct.replace(id);
864889
}
865890
Traverse::Continue
@@ -934,7 +959,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
934959
logical::BindingsOp::ProjectValue(logical::ProjectValue { expr })
935960
}
936961
};
937-
let id = self.plan.add_operator(select);
962+
let id = self.curr_plan().add_operator(select);
938963
self.current_clauses_mut().select_clause.replace(id);
939964
Traverse::Continue
940965
}
@@ -1468,8 +1493,9 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
14681493
}
14691494

14701495
fn enter_from_let(&mut self, from_let: &'ast FromLet) -> Traverse {
1471-
self.from_lets.insert(*self.current_node());
14721496
*self.current_ctx_mut() = QueryContext::FromLet;
1497+
self.enter_plan();
1498+
self.enter_benv();
14731499
self.enter_env();
14741500

14751501
let id = *self.current_node();
@@ -1485,10 +1511,20 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
14851511

14861512
fn exit_from_let(&mut self, from_let: &'ast FromLet) -> Traverse {
14871513
*self.current_ctx_mut() = QueryContext::Query;
1514+
let subplan = self.exit_plan();
1515+
let benv = self.exit_benv();
14881516
let mut env = self.exit_env();
1489-
eq_or_fault!(self, env.len(), 1, "env.len() != 1");
1517+
eq_or_fault!(self, env.len() + benv.len(), 1, "env.len()+benv.len() != 1");
14901518

1491-
let expr = env.pop().unwrap();
1519+
let expr = if !benv.is_empty() {
1520+
// Subquery in From Let
1521+
let subq = logical::SubQueryExpr { plan: subplan };
1522+
ValueExpr::SubQueryExpr(subq)
1523+
} else {
1524+
// Expression in From Let
1525+
self.curr_plan().merge_plan(subplan); // merge in subplan, as there is no subquery
1526+
env.pop().unwrap()
1527+
};
14921528

14931529
let FromLet {
14941530
kind,
@@ -1527,10 +1563,13 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
15271563
ProjectAllMode::Unwrap,
15281564
),
15291565
};
1530-
let id = self.plan.add_operator(bexpr);
1566+
1567+
let id = self.curr_plan().add_operator(bexpr);
15311568
self.push_bexpr(id);
1569+
15321570
if let Some(select_id) = self.current_clauses_mut().select_clause {
1533-
if let Some(BindingsOp::ProjectAll(mode)) = self.plan.operator_as_mut(select_id) {
1571+
if let Some(BindingsOp::ProjectAll(mode)) = self.curr_plan().operator_as_mut(select_id)
1572+
{
15341573
*mode = project_all_mode
15351574
}
15361575
}
@@ -1546,7 +1585,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
15461585

15471586
fn exit_join(&mut self, join: &'ast Join) -> Traverse {
15481587
let mut benv = self.exit_benv();
1549-
eq_or_fault!(self, benv.len(), 2, "benv.len() != 2");
1588+
eq_or_fault!(self, benv.len(), 2, "j benv.len() != 2");
15501589

15511590
let mut env = self.exit_env();
15521591
true_or_fault!(
@@ -1569,17 +1608,17 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
15691608

15701609
let rid = benv.pop().unwrap();
15711610
let lid = benv.pop().unwrap();
1572-
let left = Box::new(self.plan.operator(lid).unwrap().clone());
1573-
let right = Box::new(self.plan.operator(rid).unwrap().clone());
1611+
let left = Box::new(self.curr_plan().operator(lid).unwrap().clone());
1612+
let right = Box::new(self.curr_plan().operator(rid).unwrap().clone());
15741613
let join = logical::BindingsOp::Join(logical::Join {
15751614
kind,
15761615
left,
15771616
right,
15781617
on,
15791618
});
1580-
let join = self.plan.add_operator(join);
1581-
self.plan.add_flow_with_branch_num(lid, join, 0);
1582-
self.plan.add_flow_with_branch_num(rid, join, 1);
1619+
let join = self.curr_plan().add_operator(join);
1620+
self.curr_plan().add_flow_with_branch_num(lid, join, 0);
1621+
self.curr_plan().add_flow_with_branch_num(rid, join, 1);
15831622
self.push_bexpr(join);
15841623
Traverse::Continue
15851624
}
@@ -1611,7 +1650,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
16111650
let filter = logical::BindingsOp::Filter(logical::Filter {
16121651
expr: env.pop().unwrap(),
16131652
});
1614-
let id = self.plan.add_operator(filter);
1653+
let id = self.curr_plan().add_operator(filter);
16151654

16161655
self.current_clauses_mut().where_clause.replace(id);
16171656
Traverse::Continue
@@ -1629,7 +1668,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
16291668
let having = BindingsOp::Having(logical::Having {
16301669
expr: env.pop().unwrap(),
16311670
});
1632-
let id = self.plan.add_operator(having);
1671+
let id = self.curr_plan().add_operator(having);
16331672

16341673
self.current_clauses_mut().having_clause.replace(id);
16351674
Traverse::Continue
@@ -1679,8 +1718,9 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
16791718
));
16801719
return Traverse::Stop;
16811720
}
1721+
let mut errors = Vec::default();
16821722
let select_clause = self
1683-
.plan
1723+
.curr_plan()
16841724
.operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None"))
16851725
.unwrap();
16861726
let mut binding = Vec::new();
@@ -1705,14 +1745,14 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
17051745
logical::Lit::String(s) => s.clone(),
17061746
_ => {
17071747
// Report error but allow visitor to continue
1708-
self.errors.push(AstTransformError::IllegalState(
1748+
errors.push(AstTransformError::IllegalState(
17091749
"Unexpected literal type".to_string(),
17101750
));
17111751
String::new()
17121752
}
17131753
},
17141754
_ => {
1715-
self.errors.push(AstTransformError::IllegalState(
1755+
errors.push(AstTransformError::IllegalState(
17161756
"Unexpected alias type".to_string(),
17171757
));
17181758
return Traverse::Stop;
@@ -1728,14 +1768,16 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
17281768
exprs.insert(alias, value);
17291769
}
17301770

1771+
self.errors.extend(errors);
1772+
17311773
let group_by: BindingsOp = BindingsOp::GroupBy(logical::GroupBy {
17321774
strategy,
17331775
exprs,
17341776
aggregate_exprs,
17351777
group_as_alias,
17361778
});
17371779

1738-
let id = self.plan.add_operator(group_by);
1780+
let id = self.curr_plan().add_operator(group_by);
17391781
self.current_clauses_mut().group_by_clause.replace(id);
17401782
Traverse::Continue
17411783
}
@@ -1763,7 +1805,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
17631805
fn exit_order_by_expr(&mut self, _order_by_expr: &'ast OrderByExpr) -> Traverse {
17641806
let specs = self.exit_sort();
17651807
let order_by = logical::BindingsOp::OrderBy(logical::OrderBy { specs });
1766-
let id = self.plan.add_operator(order_by);
1808+
let id = self.curr_plan().add_operator(order_by);
17671809
if matches!(self.current_ctx(), Some(QueryContext::Query)) {
17681810
self.current_clauses_mut().order_by_clause.replace(id);
17691811
} else {
@@ -1836,7 +1878,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> {
18361878
};
18371879

18381880
let limit_offset = logical::BindingsOp::LimitOffset(logical::LimitOffset { limit, offset });
1839-
let id = self.plan.add_operator(limit_offset);
1881+
let id = self.curr_plan().add_operator(limit_offset);
18401882
if matches!(self.current_ctx(), Some(QueryContext::Query)) {
18411883
self.current_clauses_mut().limit_offset_clause.replace(id);
18421884
} else {

0 commit comments

Comments
 (0)