@@ -19,7 +19,7 @@ struct WordWithIndex {
1919}
2020
2121impl WordWithIndex {
22- fn under_cursor ( & self , cursor_pos : usize ) -> bool {
22+ fn is_under_cursor ( & self , cursor_pos : usize ) -> bool {
2323 self . start <= cursor_pos && self . end > cursor_pos
2424 }
2525
@@ -30,54 +30,58 @@ impl WordWithIndex {
3030 }
3131}
3232
33+ /// Note: A policy name within quotation marks will be considered a single word.
3334fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
3435 let mut words = vec ! [ ] ;
3536
36- let mut start : Option < usize > = None ;
37+ let mut start_of_word : Option < usize > = None ;
3738 let mut current_word = String :: new ( ) ;
3839 let mut in_quotation_marks = false ;
3940
40- for ( pos , c ) in sql. char_indices ( ) {
41- if ( c . is_ascii_whitespace ( ) || c == ';' )
41+ for ( current_position , current_char ) in sql. char_indices ( ) {
42+ if ( current_char . is_ascii_whitespace ( ) || current_char == ';' )
4243 && !current_word. is_empty ( )
43- && start . is_some ( )
44+ && start_of_word . is_some ( )
4445 && !in_quotation_marks
4546 {
4647 words. push ( WordWithIndex {
4748 word : current_word,
48- start : start . unwrap ( ) ,
49- end : pos ,
49+ start : start_of_word . unwrap ( ) ,
50+ end : current_position ,
5051 } ) ;
52+
5153 current_word = String :: new ( ) ;
52- start = None ;
53- } else if ( c. is_ascii_whitespace ( ) || c == ';' ) && current_word. is_empty ( ) {
54+ start_of_word = None ;
55+ } else if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
56+ && current_word. is_empty ( )
57+ {
5458 // do nothing
55- } else if c == '"' && start . is_none ( ) {
59+ } else if current_char == '"' && start_of_word . is_none ( ) {
5660 in_quotation_marks = true ;
57- start = Some ( pos ) ;
58- current_word . push ( c ) ;
59- } else if c == '"' && start . is_some ( ) {
60- current_word. push ( c ) ;
61+ current_word . push ( current_char ) ;
62+ start_of_word = Some ( current_position ) ;
63+ } else if current_char == '"' && start_of_word . is_some ( ) {
64+ current_word. push ( current_char ) ;
6165 words. push ( WordWithIndex {
6266 word : current_word,
63- start : start . unwrap ( ) ,
64- end : pos + 1 ,
67+ start : start_of_word . unwrap ( ) ,
68+ end : current_position + 1 ,
6569 } ) ;
6670 in_quotation_marks = false ;
67- start = None ;
71+ start_of_word = None ;
6872 current_word = String :: new ( )
69- } else if start . is_some ( ) {
70- current_word. push ( c )
73+ } else if start_of_word . is_some ( ) {
74+ current_word. push ( current_char )
7175 } else {
72- start = Some ( pos ) ;
73- current_word. push ( c ) ;
76+ start_of_word = Some ( current_position ) ;
77+ current_word. push ( current_char ) ;
7478 }
7579 }
7680
77- if !current_word. is_empty ( ) && start . is_some ( ) {
81+ if !current_word. is_empty ( ) && start_of_word . is_some ( ) {
7882 words. push ( WordWithIndex {
7983 word : current_word,
80- start : start . unwrap ( ) ,
84+ start : start_of_word . unwrap ( ) ,
8185 end : sql. len ( ) ,
8286 } ) ;
8387 }
@@ -100,6 +104,10 @@ pub(crate) struct PolicyContext {
100104 pub node_kind : String ,
101105}
102106
107+ /// Simple parser that'll turn a policy-related statement into a context object required for
108+ /// completions.
109+ /// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`.
110+ /// It can only parse policy statements.
103111pub ( crate ) struct PolicyParser {
104112 tokens : Peekable < std:: vec:: IntoIter < WordWithIndex > > ,
105113 previous_token : Option < WordWithIndex > ,
@@ -136,7 +144,7 @@ impl PolicyParser {
136144
137145 fn parse ( mut self ) -> PolicyContext {
138146 while let Some ( token) = self . advance ( ) {
139- if token. under_cursor ( self . cursor_position ) {
147+ if token. is_under_cursor ( self . cursor_position ) {
140148 self . handle_token_under_cursor ( token) ;
141149 } else {
142150 self . handle_token ( token) ;
@@ -161,9 +169,8 @@ impl PolicyParser {
161169 }
162170 "on" => {
163171 if token. word . contains ( '.' ) {
164- let mut parts = token . word . split ( '.' ) ;
172+ let ( schema_name , table_name ) = self . schema_and_table_name ( & token ) ;
165173
166- let schema_name: String = parts. next ( ) . unwrap ( ) . into ( ) ;
167174 let schema_name_len = schema_name. len ( ) ;
168175 self . context . schema_name = Some ( schema_name) ;
169176
@@ -176,8 +183,16 @@ impl PolicyParser {
176183 . expect ( "Text too long" ) ;
177184
178185 self . context . node_range = range_without_schema;
179- self . context . node_text = parts. next ( ) . unwrap ( ) . into ( ) ;
180186 self . context . node_kind = "policy_table" . into ( ) ;
187+
188+ self . context . node_text = match table_name {
189+ Some ( node_text) => node_text,
190+
191+ // In practice, this should never happen.
192+ // The completion sanitization will add a word after a `.` if nothing follows it;
193+ // the token_text will then look like `schema.REPLACED_TOKEN`.
194+ None => String :: new ( ) ,
195+ } ;
181196 } else {
182197 self . context . node_range = token. get_range ( ) ;
183198 self . context . node_text = token. word ;
@@ -209,7 +224,7 @@ impl PolicyParser {
209224 }
210225 "on" => self . table_with_schema ( ) ,
211226
212- // skip the "to" so we don't parse it as the TO rolename
227+ // skip the "to" so we don't parse it as the TO rolename when it's under the cursor
213228 "rename" if self . next_matches ( "to" ) => {
214229 self . advance ( ) ;
215230 }
@@ -231,17 +246,18 @@ impl PolicyParser {
231246 }
232247
233248 fn advance ( & mut self ) -> Option < WordWithIndex > {
249+ // we can't peek back n an iterator, so we'll have to keep track manually.
234250 self . previous_token = self . current_token . take ( ) ;
235251 self . current_token = self . tokens . next ( ) ;
236252 self . current_token . clone ( )
237253 }
238254
239255 fn table_with_schema ( & mut self ) {
240256 self . advance ( ) . map ( |token| {
241- if token. under_cursor ( self . cursor_position ) {
257+ if token. is_under_cursor ( self . cursor_position ) {
242258 self . handle_token_under_cursor ( token) ;
243259 } else if token. word . contains ( '.' ) {
244- let ( schema, maybe_table) = self . schema_and_table_name ( token) ;
260+ let ( schema, maybe_table) = self . schema_and_table_name ( & token) ;
245261 self . context . schema_name = Some ( schema) ;
246262 self . context . table_name = maybe_table;
247263 } else {
@@ -250,7 +266,7 @@ impl PolicyParser {
250266 } ) ;
251267 }
252268
253- fn schema_and_table_name ( & self , token : WordWithIndex ) -> ( String , Option < String > ) {
269+ fn schema_and_table_name ( & self , token : & WordWithIndex ) -> ( String , Option < String > ) {
254270 let mut parts = token. word . split ( '.' ) ;
255271
256272 (
0 commit comments