@@ -7,7 +7,9 @@ use pgt_treesitter_queries::{
77 queries:: { self , QueryResult } ,
88} ;
99
10- use crate :: sanitization:: SanitizedCompletionParams ;
10+ use crate :: {
11+ NodeText , context:: policy_parser:: PolicyParser , sanitization:: SanitizedCompletionParams ,
12+ } ;
1113
1214#[ derive( Debug , PartialEq , Eq ) ]
1315pub enum WrappingClause < ' a > {
@@ -19,12 +21,8 @@ pub enum WrappingClause<'a> {
1921 } ,
2022 Update ,
2123 Delete ,
22- }
23-
24- #[ derive( PartialEq , Eq , Debug ) ]
25- pub ( crate ) enum NodeText < ' a > {
26- Replaced ,
27- Original ( & ' a str ) ,
24+ PolicyName ,
25+ ToRole ,
2826}
2927
3028/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
@@ -45,7 +43,7 @@ pub enum WrappingNode {
4543pub ( crate ) enum NodeUnderCursor < ' a > {
4644 TsNode ( tree_sitter:: Node < ' a > ) ,
4745 CustomNode {
48- text : NodeText < ' a > ,
46+ text : NodeText ,
4947 range : TextRange ,
5048 kind : String ,
5149 } ,
@@ -172,14 +170,35 @@ impl<'a> CompletionContext<'a> {
172170 // policy handling is important to Supabase, but they are a PostgreSQL specific extension,
173171 // so the tree_sitter_sql language does not support it.
174172 // We infer the context manually.
175- // if params.text.to_lowercase().starts_with("create policy")
176- // || params.text.to_lowercase().starts_with("alter policy")
177- // || params.text.to_lowercase().starts_with("drop policy")
178- // {
179- // } else {
180- ctx. gather_tree_context ( ) ;
181- ctx. gather_info_from_ts_queries ( ) ;
182- // }
173+ if params. text . to_lowercase ( ) . starts_with ( "create policy" )
174+ || params. text . to_lowercase ( ) . starts_with ( "alter policy" )
175+ || params. text . to_lowercase ( ) . starts_with ( "drop policy" )
176+ {
177+ let policy_context = PolicyParser :: get_context ( & ctx. text , ctx. position ) ;
178+
179+ ctx. node_under_cursor = Some ( NodeUnderCursor :: CustomNode {
180+ text : policy_context. node_text . into ( ) ,
181+ range : policy_context. node_range ,
182+ kind : policy_context. node_kind . clone ( ) ,
183+ } ) ;
184+
185+ if policy_context. table_name . is_some ( ) {
186+ let mut new = HashSet :: new ( ) ;
187+ new. insert ( policy_context. table_name . unwrap ( ) ) ;
188+ ctx. mentioned_relations
189+ . insert ( policy_context. schema_name , new) ;
190+ }
191+
192+ ctx. wrapping_clause_type = match policy_context. node_kind . as_str ( ) {
193+ "policy_name" => Some ( WrappingClause :: PolicyName ) ,
194+ "policy_role" => Some ( WrappingClause :: ToRole ) ,
195+ "policy_table" => Some ( WrappingClause :: From ) ,
196+ _ => None ,
197+ } ;
198+ } else {
199+ ctx. gather_tree_context ( ) ;
200+ ctx. gather_info_from_ts_queries ( ) ;
201+ }
183202
184203 tracing:: warn!( "sql: {}" , ctx. text) ;
185204 tracing:: warn!( "position: {}" , ctx. position) ;
@@ -237,13 +256,13 @@ impl<'a> CompletionContext<'a> {
237256 }
238257 }
239258
240- fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText < ' a > > {
259+ fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText > {
241260 let source = self . text ;
242261 ts_node. utf8_text ( source. as_bytes ( ) ) . ok ( ) . map ( |txt| {
243262 if SanitizedCompletionParams :: is_sanitized_token ( txt) {
244263 NodeText :: Replaced
245264 } else {
246- NodeText :: Original ( txt)
265+ NodeText :: Original ( txt. into ( ) )
247266 }
248267 } )
249268 }
@@ -386,7 +405,7 @@ impl<'a> CompletionContext<'a> {
386405 NodeText :: Original ( txt) => Some ( txt) ,
387406 NodeText :: Replaced => None ,
388407 } ) {
389- match txt {
408+ match txt. as_str ( ) {
390409 "where" => return Some ( WrappingClause :: Where ) ,
391410 "update" => return Some ( WrappingClause :: Update ) ,
392411 "select" => return Some ( WrappingClause :: Select ) ,
@@ -436,7 +455,8 @@ impl<'a> CompletionContext<'a> {
436455#[ cfg( test) ]
437456mod tests {
438457 use crate :: {
439- context:: { CompletionContext , NodeText , WrappingClause } ,
458+ NodeText ,
459+ context:: { CompletionContext , WrappingClause } ,
440460 sanitization:: SanitizedCompletionParams ,
441461 test_helper:: { CURSOR_POS , get_text_and_position} ,
442462 } ;
@@ -607,7 +627,7 @@ mod tests {
607627 NodeUnderCursor :: TsNode ( node) => {
608628 assert_eq ! (
609629 ctx. get_ts_node_content( node) ,
610- Some ( NodeText :: Original ( "select" ) )
630+ Some ( NodeText :: Original ( "select" . into ( ) ) )
611631 ) ;
612632
613633 assert_eq ! (
@@ -643,7 +663,7 @@ mod tests {
643663 NodeUnderCursor :: TsNode ( node) => {
644664 assert_eq ! (
645665 ctx. get_ts_node_content( & node) ,
646- Some ( NodeText :: Original ( "from" ) )
666+ Some ( NodeText :: Original ( "from" . into ( ) ) )
647667 ) ;
648668 }
649669 _ => unreachable ! ( ) ,
@@ -671,7 +691,10 @@ mod tests {
671691
672692 match node {
673693 NodeUnderCursor :: TsNode ( node) => {
674- assert_eq ! ( ctx. get_ts_node_content( & node) , Some ( NodeText :: Original ( "" ) ) ) ;
694+ assert_eq ! (
695+ ctx. get_ts_node_content( & node) ,
696+ Some ( NodeText :: Original ( "" . into( ) ) )
697+ ) ;
675698 assert_eq ! ( ctx. wrapping_clause_type, None ) ;
676699 }
677700 _ => unreachable ! ( ) ,
@@ -703,7 +726,7 @@ mod tests {
703726 NodeUnderCursor :: TsNode ( node) => {
704727 assert_eq ! (
705728 ctx. get_ts_node_content( & node) ,
706- Some ( NodeText :: Original ( "fro" ) )
729+ Some ( NodeText :: Original ( "fro" . into ( ) ) )
707730 ) ;
708731 assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
709732 }
0 commit comments