1
1
use std:: iter:: once;
2
2
3
- use ide_db:: {
4
- syntax_helpers:: node_ext:: { is_pattern_cond, single_let} ,
5
- ty_filter:: TryEnum ,
6
- } ;
3
+ use either:: Either ;
4
+ use hir:: { Semantics , TypeInfo } ;
5
+ use ide_db:: { RootDatabase , ty_filter:: TryEnum } ;
7
6
use syntax:: {
8
7
AstNode ,
9
- SyntaxKind :: { FN , FOR_EXPR , LOOP_EXPR , WHILE_EXPR , WHITESPACE } ,
10
- T ,
8
+ SyntaxKind :: { CLOSURE_EXPR , FN , FOR_EXPR , LOOP_EXPR , WHILE_EXPR , WHITESPACE } ,
9
+ SyntaxNode , T ,
11
10
ast:: {
12
11
self ,
13
12
edit:: { AstNodeEdit , IndentLevel } ,
@@ -44,12 +43,9 @@ use crate::{
44
43
// }
45
44
// ```
46
45
pub ( crate ) fn convert_to_guarded_return ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
47
- if let Some ( let_stmt) = ctx. find_node_at_offset ( ) {
48
- let_stmt_to_guarded_return ( let_stmt, acc, ctx)
49
- } else if let Some ( if_expr) = ctx. find_node_at_offset ( ) {
50
- if_expr_to_guarded_return ( if_expr, acc, ctx)
51
- } else {
52
- None
46
+ match ctx. find_node_at_offset :: < Either < ast:: LetStmt , ast:: IfExpr > > ( ) ? {
47
+ Either :: Left ( let_stmt) => let_stmt_to_guarded_return ( let_stmt, acc, ctx) ,
48
+ Either :: Right ( if_expr) => if_expr_to_guarded_return ( if_expr, acc, ctx) ,
53
49
}
54
50
}
55
51
@@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
73
69
return None ;
74
70
}
75
71
76
- // Check if there is an IfLet that we can handle.
77
- let ( if_let_pat, cond_expr) = if is_pattern_cond ( cond. clone ( ) ) {
78
- let let_ = single_let ( cond) ?;
79
- ( Some ( let_. pat ( ) ?) , let_. expr ( ) ?)
80
- } else {
81
- ( None , cond)
82
- } ;
72
+ let let_chains = flat_let_chain ( cond) ;
83
73
84
74
let then_block = if_expr. then_branch ( ) ?;
85
75
let then_block = then_block. stmt_list ( ) ?;
@@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
106
96
107
97
let parent_container = parent_block. syntax ( ) . parent ( ) ?;
108
98
109
- let early_expression: ast:: Expr = match parent_container. kind ( ) {
110
- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
111
- FN => make:: expr_return ( None ) ,
112
- _ => return None ,
113
- } ;
99
+ let early_expression: ast:: Expr = early_expression ( parent_container, & ctx. sema ) ?;
114
100
115
101
then_block. syntax ( ) . first_child_or_token ( ) . map ( |t| t. kind ( ) == T ! [ '{' ] ) ?;
116
102
@@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
132
118
target,
133
119
|edit| {
134
120
let if_indent_level = IndentLevel :: from_node ( if_expr. syntax ( ) ) ;
135
- let replacement = match if_let_pat {
136
- None => {
137
- // If.
138
- let new_expr = {
139
- let then_branch =
140
- make:: block_expr ( once ( make:: expr_stmt ( early_expression) . into ( ) ) , None ) ;
141
- let cond = invert_boolean_expression_legacy ( cond_expr) ;
142
- make:: expr_if ( cond, then_branch, None ) . indent ( if_indent_level)
143
- } ;
144
- new_expr. syntax ( ) . clone ( )
145
- }
146
- Some ( pat) => {
121
+ let replacement = let_chains. into_iter ( ) . map ( |expr| {
122
+ if let ast:: Expr :: LetExpr ( let_expr) = & expr
123
+ && let ( Some ( pat) , Some ( expr) ) = ( let_expr. pat ( ) , let_expr. expr ( ) )
124
+ {
147
125
// If-let.
148
126
let let_else_stmt = make:: let_else_stmt (
149
127
pat,
150
128
None ,
151
- cond_expr ,
152
- ast:: make:: tail_only_block_expr ( early_expression) ,
129
+ expr ,
130
+ ast:: make:: tail_only_block_expr ( early_expression. clone ( ) ) ,
153
131
) ;
154
132
let let_else_stmt = let_else_stmt. indent ( if_indent_level) ;
155
133
let_else_stmt. syntax ( ) . clone ( )
134
+ } else {
135
+ // If.
136
+ let new_expr = {
137
+ let then_branch = make:: block_expr (
138
+ once ( make:: expr_stmt ( early_expression. clone ( ) ) . into ( ) ) ,
139
+ None ,
140
+ ) ;
141
+ let cond = invert_boolean_expression_legacy ( expr) ;
142
+ make:: expr_if ( cond, then_branch, None ) . indent ( if_indent_level)
143
+ } ;
144
+ new_expr. syntax ( ) . clone ( )
156
145
}
157
- } ;
146
+ } ) ;
158
147
148
+ let newline = & format ! ( "\n {if_indent_level}" ) ;
159
149
let then_statements = replacement
160
- . children_with_tokens ( )
150
+ . enumerate ( )
151
+ . flat_map ( |( i, node) | {
152
+ ( i != 0 )
153
+ . then ( || make:: tokens:: whitespace ( newline) . into ( ) )
154
+ . into_iter ( )
155
+ . chain ( node. children_with_tokens ( ) )
156
+ } )
161
157
. chain (
162
158
then_block_items
163
159
. syntax ( )
@@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
201
197
let_stmt. syntax ( ) . parent ( ) ?. ancestors ( ) . find_map ( ast:: BlockExpr :: cast) ?;
202
198
let parent_container = parent_block. syntax ( ) . parent ( ) ?;
203
199
204
- match parent_container. kind ( ) {
205
- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
206
- FN => make:: expr_return ( None ) ,
207
- _ => return None ,
208
- }
200
+ early_expression ( parent_container, & ctx. sema ) ?
209
201
} ;
210
202
211
203
acc. add (
@@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
232
224
)
233
225
}
234
226
227
+ fn early_expression (
228
+ parent_container : SyntaxNode ,
229
+ sema : & Semantics < ' _ , RootDatabase > ,
230
+ ) -> Option < ast:: Expr > {
231
+ let return_none_expr = || {
232
+ let none_expr = make:: expr_path ( make:: ext:: ident_path ( "None" ) ) ;
233
+ make:: expr_return ( Some ( none_expr) )
234
+ } ;
235
+ if let Some ( fn_) = ast:: Fn :: cast ( parent_container. clone ( ) )
236
+ && let Some ( fn_def) = sema. to_def ( & fn_)
237
+ && let Some ( TryEnum :: Option ) = TryEnum :: from_ty ( sema, & fn_def. ret_type ( sema. db ) )
238
+ {
239
+ return Some ( return_none_expr ( ) ) ;
240
+ }
241
+ if let Some ( body) = ast:: ClosureExpr :: cast ( parent_container. clone ( ) ) . and_then ( |it| it. body ( ) )
242
+ && let Some ( ret_ty) = sema. type_of_expr ( & body) . map ( TypeInfo :: original)
243
+ && let Some ( TryEnum :: Option ) = TryEnum :: from_ty ( sema, & ret_ty)
244
+ {
245
+ return Some ( return_none_expr ( ) ) ;
246
+ }
247
+
248
+ Some ( match parent_container. kind ( ) {
249
+ WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
250
+ FN | CLOSURE_EXPR => make:: expr_return ( None ) ,
251
+ _ => return None ,
252
+ } )
253
+ }
254
+
255
+ fn flat_let_chain ( mut expr : ast:: Expr ) -> Vec < ast:: Expr > {
256
+ let mut chains = vec ! [ ] ;
257
+
258
+ while let ast:: Expr :: BinExpr ( bin_expr) = & expr
259
+ && bin_expr. op_kind ( ) == Some ( ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) )
260
+ && let ( Some ( lhs) , Some ( rhs) ) = ( bin_expr. lhs ( ) , bin_expr. rhs ( ) )
261
+ {
262
+ if let Some ( last) = chains. pop_if ( |last| !matches ! ( last, ast:: Expr :: LetExpr ( _) ) ) {
263
+ chains. push ( make:: expr_bin_op ( rhs, ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) , last) ) ;
264
+ } else {
265
+ chains. push ( rhs) ;
266
+ }
267
+ expr = lhs;
268
+ }
269
+
270
+ chains. push ( expr) ;
271
+ chains. reverse ( ) ;
272
+ chains
273
+ }
274
+
235
275
#[ cfg( test) ]
236
276
mod tests {
237
277
use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -268,6 +308,71 @@ fn main() {
268
308
) ;
269
309
}
270
310
311
+ #[ test]
312
+ fn convert_inside_fn_return_option ( ) {
313
+ check_assist (
314
+ convert_to_guarded_return,
315
+ r#"
316
+ //- minicore: option
317
+ fn ret_option() -> Option<()> {
318
+ bar();
319
+ if$0 true {
320
+ foo();
321
+
322
+ // comment
323
+ bar();
324
+ }
325
+ }
326
+ "# ,
327
+ r#"
328
+ fn ret_option() -> Option<()> {
329
+ bar();
330
+ if false {
331
+ return None;
332
+ }
333
+ foo();
334
+
335
+ // comment
336
+ bar();
337
+ }
338
+ "# ,
339
+ ) ;
340
+ }
341
+
342
+ #[ test]
343
+ fn convert_inside_closure ( ) {
344
+ check_assist (
345
+ convert_to_guarded_return,
346
+ r#"
347
+ fn main() {
348
+ let _f = || {
349
+ bar();
350
+ if$0 true {
351
+ foo();
352
+
353
+ // comment
354
+ bar();
355
+ }
356
+ }
357
+ }
358
+ "# ,
359
+ r#"
360
+ fn main() {
361
+ let _f = || {
362
+ bar();
363
+ if false {
364
+ return;
365
+ }
366
+ foo();
367
+
368
+ // comment
369
+ bar();
370
+ }
371
+ }
372
+ "# ,
373
+ ) ;
374
+ }
375
+
271
376
#[ test]
272
377
fn convert_let_inside_fn ( ) {
273
378
check_assist (
@@ -316,6 +421,82 @@ fn main() {
316
421
) ;
317
422
}
318
423
424
+ #[ test]
425
+ fn convert_if_let_result_inside_let ( ) {
426
+ check_assist (
427
+ convert_to_guarded_return,
428
+ r#"
429
+ fn main() {
430
+ let _x = loop {
431
+ if$0 let Ok(x) = Err(92) {
432
+ foo(x);
433
+ }
434
+ };
435
+ }
436
+ "# ,
437
+ r#"
438
+ fn main() {
439
+ let _x = loop {
440
+ let Ok(x) = Err(92) else { continue };
441
+ foo(x);
442
+ };
443
+ }
444
+ "# ,
445
+ ) ;
446
+ }
447
+
448
+ #[ test]
449
+ fn convert_if_let_chain_result ( ) {
450
+ check_assist (
451
+ convert_to_guarded_return,
452
+ r#"
453
+ fn main() {
454
+ if$0 let Ok(x) = Err(92)
455
+ && x < 30
456
+ && let Some(y) = Some(8)
457
+ {
458
+ foo(x, y);
459
+ }
460
+ }
461
+ "# ,
462
+ r#"
463
+ fn main() {
464
+ let Ok(x) = Err(92) else { return };
465
+ if x >= 30 {
466
+ return;
467
+ }
468
+ let Some(y) = Some(8) else { return };
469
+ foo(x, y);
470
+ }
471
+ "# ,
472
+ ) ;
473
+
474
+ check_assist (
475
+ convert_to_guarded_return,
476
+ r#"
477
+ fn main() {
478
+ if$0 let Ok(x) = Err(92)
479
+ && x < 30
480
+ && y < 20
481
+ && let Some(y) = Some(8)
482
+ {
483
+ foo(x, y);
484
+ }
485
+ }
486
+ "# ,
487
+ r#"
488
+ fn main() {
489
+ let Ok(x) = Err(92) else { return };
490
+ if !(x < 30 && y < 20) {
491
+ return;
492
+ }
493
+ let Some(y) = Some(8) else { return };
494
+ foo(x, y);
495
+ }
496
+ "# ,
497
+ ) ;
498
+ }
499
+
319
500
#[ test]
320
501
fn convert_let_ok_inside_fn ( ) {
321
502
check_assist (
@@ -560,6 +741,32 @@ fn main() {
560
741
) ;
561
742
}
562
743
744
+ #[ test]
745
+ fn convert_let_stmt_inside_fn_return_option ( ) {
746
+ check_assist (
747
+ convert_to_guarded_return,
748
+ r#"
749
+ //- minicore: option
750
+ fn foo() -> Option<i32> {
751
+ None
752
+ }
753
+
754
+ fn ret_option() -> Option<i32> {
755
+ let x$0 = foo();
756
+ }
757
+ "# ,
758
+ r#"
759
+ fn foo() -> Option<i32> {
760
+ None
761
+ }
762
+
763
+ fn ret_option() -> Option<i32> {
764
+ let Some(x) = foo() else { return None };
765
+ }
766
+ "# ,
767
+ ) ;
768
+ }
769
+
563
770
#[ test]
564
771
fn convert_let_stmt_inside_loop ( ) {
565
772
check_assist (
0 commit comments