@@ -3,9 +3,9 @@ use std::num::NonZeroUsize;
33use std:: sync:: { Arc , LazyLock , Mutex } ;
44
55use lru:: LruCache ;
6+ use pgt_lexer:: lex;
67use pgt_query_ext:: diagnostics:: * ;
78use pgt_text_size:: TextRange ;
8- use pgt_tokenizer:: tokenize;
99use regex:: Regex ;
1010
1111use super :: statement_identifier:: StatementId ;
@@ -104,6 +104,27 @@ fn is_composite_type_error(err: &str) -> bool {
104104 COMPOSITE_TYPE_ERROR_RE . is_match ( err)
105105}
106106
107+ // Keywords that, when preceding a named parameter, indicate that the parameter should be treated
108+ // as an identifier rather than a positional parameter.
109+ const IDENTIFIER_CONTEXT : [ pgt_lexer:: SyntaxKind ; 15 ] = [
110+ pgt_lexer:: SyntaxKind :: TO_KW ,
111+ pgt_lexer:: SyntaxKind :: FROM_KW ,
112+ pgt_lexer:: SyntaxKind :: SCHEMA_KW ,
113+ pgt_lexer:: SyntaxKind :: TABLE_KW ,
114+ pgt_lexer:: SyntaxKind :: INDEX_KW ,
115+ pgt_lexer:: SyntaxKind :: CONSTRAINT_KW ,
116+ pgt_lexer:: SyntaxKind :: OWNER_KW ,
117+ pgt_lexer:: SyntaxKind :: ROLE_KW ,
118+ pgt_lexer:: SyntaxKind :: USER_KW ,
119+ pgt_lexer:: SyntaxKind :: DATABASE_KW ,
120+ pgt_lexer:: SyntaxKind :: TYPE_KW ,
121+ pgt_lexer:: SyntaxKind :: CAST_KW ,
122+ pgt_lexer:: SyntaxKind :: ALTER_KW ,
123+ pgt_lexer:: SyntaxKind :: DROP_KW ,
124+ // for schema.table style identifiers
125+ pgt_lexer:: SyntaxKind :: DOT ,
126+ ] ;
127+
107128/// Converts named parameters in a SQL query string to positional parameters.
108129///
109130/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
@@ -116,13 +137,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
116137 let mut result = String :: with_capacity ( text. len ( ) ) ;
117138 let mut param_mapping: HashMap < & str , usize > = HashMap :: new ( ) ;
118139 let mut param_index = 1 ;
119- let mut position = 0 ;
120140
121- for token in tokenize ( text) {
122- let token_len = token. len as usize ;
123- let token_text = & text[ position..position + token_len] ;
141+ let lexed = lex ( text) ;
142+ for ( token_idx, kind) in lexed. tokens ( ) . enumerate ( ) {
143+ if kind == pgt_lexer:: SyntaxKind :: EOF {
144+ break ;
145+ }
146+
147+ let token_text = lexed. text ( token_idx) ;
124148
125- if matches ! ( token . kind, pgt_tokenizer :: TokenKind :: NamedParam { .. } ) {
149+ if matches ! ( kind, pgt_lexer :: SyntaxKind :: NAMED_PARAM ) {
126150 let idx = match param_mapping. get ( token_text) {
127151 Some ( & index) => index,
128152 None => {
@@ -133,7 +157,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
133157 }
134158 } ;
135159
136- let replacement = format ! ( "${}" , idx) ;
160+ // find previous non-trivia token
161+ let prev_token = ( 0 ..token_idx)
162+ . rev ( )
163+ . map ( |i| lexed. kind ( i) )
164+ . find ( |kind| !kind. is_trivia ( ) ) ;
165+
166+ let replacement = match prev_token {
167+ Some ( k) if IDENTIFIER_CONTEXT . contains ( & k) => deterministic_identifier ( idx - 1 ) ,
168+ _ => format ! ( "${}" , idx) ,
169+ } ;
137170 let original_len = token_text. len ( ) ;
138171 let replacement_len = replacement. len ( ) ;
139172
@@ -146,17 +179,45 @@ pub fn convert_to_positional_params(text: &str) -> String {
146179 } else {
147180 result. push_str ( token_text) ;
148181 }
149-
150- position += token_len;
151182 }
152183
153184 result
154185}
155186
187+ const ALPHABET : [ char ; 26 ] = [
188+ 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' , 'h' , 'i' , 'j' , 'k' , 'l' , 'm' , 'n' , 'o' , 'p' , 'q' , 'r' , 's' ,
189+ 't' , 'u' , 'v' , 'w' , 'x' , 'y' , 'z' ,
190+ ] ;
191+
192+ /// Generates a deterministic identifier based on the given index.
193+ fn deterministic_identifier ( idx : usize ) -> String {
194+ let iteration = idx / ALPHABET . len ( ) ;
195+ let pos = idx % ALPHABET . len ( ) ;
196+
197+ format ! (
198+ "{}{}" ,
199+ ALPHABET [ pos] ,
200+ if iteration > 0 {
201+ deterministic_identifier( iteration - 1 )
202+ } else {
203+ "" . to_string( )
204+ }
205+ )
206+ }
207+
156208#[ cfg( test) ]
157209mod tests {
158210 use super :: * ;
159211
212+ #[ test]
213+ fn test_deterministic_identifier ( ) {
214+ assert_eq ! ( deterministic_identifier( 0 ) , "a" ) ;
215+ assert_eq ! ( deterministic_identifier( 25 ) , "z" ) ;
216+ assert_eq ! ( deterministic_identifier( 26 ) , "aa" ) ;
217+ assert_eq ! ( deterministic_identifier( 27 ) , "ba" ) ;
218+ assert_eq ! ( deterministic_identifier( 51 ) , "za" ) ;
219+ }
220+
160221 #[ test]
161222 fn test_convert_to_positional_params ( ) {
162223 let input = "select * from users where id = @one and name = :two and email = :'three';" ;
@@ -177,6 +238,24 @@ mod tests {
177238 ) ;
178239 }
179240
241+ #[ test]
242+ fn test_positional_params_in_grant ( ) {
243+ let input = "grant usage on schema public, app_public, app_hidden to :DB_ROLE;" ;
244+
245+ let result = convert_to_positional_params ( input) ;
246+
247+ assert_eq ! (
248+ result,
249+ "grant usage on schema public, app_public, app_hidden to a ;"
250+ ) ;
251+
252+ let store = PgQueryStore :: new ( ) ;
253+
254+ let res = store. get_or_cache_ast ( & StatementId :: new ( input) ) ;
255+
256+ assert ! ( res. is_ok( ) ) ;
257+ }
258+
180259 #[ test]
181260 fn test_plpgsql_syntax_error ( ) {
182261 let input = "
0 commit comments