11use graph:: prelude:: BlockNumber ;
22use graph:: schema:: AggregationInterval ;
33use sqlparser:: ast:: {
4- Expr , FunctionArg , FunctionArgExpr , Ident , LimitClause , ObjectName , ObjectNamePart , Offset ,
5- Query , SetExpr , Statement , TableAlias , TableFactor , TableFunctionArgs , Value , ValueWithSpan ,
6- VisitMut , VisitorMut ,
4+ Cte , Expr , FunctionArg , FunctionArgExpr , Ident , LimitClause , ObjectName , ObjectNamePart ,
5+ Offset , Query , SetExpr , Statement , TableAlias , TableFactor , TableFunctionArgs , Value ,
6+ ValueWithSpan , VisitMut , VisitorMut ,
77} ;
88use sqlparser:: parser:: Parser ;
99use std:: result:: Result ;
@@ -36,11 +36,56 @@ pub enum Error {
3636 UnsupportedOffset ( u32 , u32 ) ,
3737 #[ error( "Qualified table names are not supported: {0}" ) ]
3838 NoQualifiedTables ( String ) ,
39+ #[ error( "Internal error: {0}" ) ]
40+ InternalError ( String ) ,
41+ }
42+
43+ /// Helper to track CTEs introduced by the main query or subqueries. Every
44+ /// time we enter a query, we need to track a new set of CTEs which must be
45+ /// discarded once we are done with that query. Otherwise, we might allow
46+ /// access to forbidden tables with a query like `select *, (with pg_user as
47+ /// (select 1) select 1) as one from pg_user`
48+ #[ derive( Default ) ]
49+ struct CteStack {
50+ stack : Vec < HashSet < String > > ,
51+ }
52+
53+ impl CteStack {
54+ fn enter_query ( & mut self ) {
55+ self . stack . push ( HashSet :: new ( ) ) ;
56+ }
57+
58+ fn exit_query ( & mut self ) {
59+ self . stack . pop ( ) ;
60+ }
61+
62+ fn contains ( & self , name : & str ) -> bool {
63+ for entry in self . stack . iter ( ) . rev ( ) {
64+ if entry. contains ( & name. to_lowercase ( ) ) {
65+ return true ;
66+ }
67+ }
68+ false
69+ }
70+
71+ fn clear ( & mut self ) {
72+ self . stack . clear ( ) ;
73+ }
74+
75+ fn add_ctes ( & mut self , ctes : & [ Cte ] ) -> ControlFlow < Error > {
76+ let Some ( entry) = self . stack . last_mut ( ) else {
77+ return ControlFlow :: Break ( Error :: InternalError ( "CTE stack is empty" . into ( ) ) ) ;
78+ } ;
79+ for cte in ctes {
80+ entry. insert ( cte. alias . name . value . to_lowercase ( ) ) ;
81+ }
82+ ControlFlow :: Continue ( ( ) )
83+ }
3984}
4085
4186pub struct Validator < ' a > {
4287 layout : & ' a Layout ,
43- ctes : HashSet < String > ,
88+ ctes : CteStack ,
4489 block : BlockNumber ,
4590 max_limit : u32 ,
4691 max_offset : u32 ,
@@ -156,12 +201,9 @@ impl VisitorMut for Validator<'_> {
156201
157202 fn pre_visit_query ( & mut self , query : & mut Query ) -> ControlFlow < Self :: Break > {
158203 // Add common table expressions to the set of known tables
204+ self . ctes . enter_query ( ) ;
159205 if let Some ( ref with) = query. with {
160- self . ctes . extend (
161- with. cte_tables
162- . iter ( )
163- . map ( |cte| cte. alias . name . value . to_lowercase ( ) ) ,
164- ) ;
206+ self . ctes . add_ctes ( & with. cte_tables ) ?;
165207 }
166208
167209 match * query. body {
@@ -177,6 +219,11 @@ impl VisitorMut for Validator<'_> {
177219 self . validate_limit_offset ( query)
178220 }
179221
222+ fn post_visit_query ( & mut self , _query : & mut Query ) -> ControlFlow < Self :: Break > {
223+ self . ctes . exit_query ( ) ;
224+ ControlFlow :: Continue ( ( ) )
225+ }
226+
180227 /// Invoked for any table function in the AST.
181228 /// See [TableFactor::Table.args](sqlparser::ast::TableFactor::Table::args) for more details identifying a table function
182229 fn post_visit_table_factor (
0 commit comments