Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 179 additions & 38 deletions crates/ide-assists/src/handlers/convert_to_guarded_return.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::iter::once;

use ide_db::{
syntax_helpers::node_ext::{is_pattern_cond, single_let},
ty_filter::TryEnum,
};
use hir::Semantics;
use ide_db::{RootDatabase, ty_filter::TryEnum};
use syntax::{
AstNode,
SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
T,
SyntaxNode, T,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
Expand Down Expand Up @@ -73,13 +71,7 @@ fn if_expr_to_guarded_return(
return None;
}

// Check if there is an IfLet that we can handle.
let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
let let_ = single_let(cond)?;
(Some(let_.pat()?), let_.expr()?)
} else {
(None, cond)
};
let let_chains = flat_let_chain(cond);

let then_block = if_expr.then_branch()?;
let then_block = then_block.stmt_list()?;
Expand All @@ -106,11 +98,7 @@ fn if_expr_to_guarded_return(

let parent_container = parent_block.syntax().parent()?;

let early_expression: ast::Expr = match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
};
let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;

then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;

Expand All @@ -132,32 +120,42 @@ fn if_expr_to_guarded_return(
target,
|edit| {
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
let replacement = match if_let_pat {
None => {
// If.
let new_expr = {
let then_branch =
make::block_expr(once(make::expr_stmt(early_expression).into()), None);
let cond = invert_boolean_expression_legacy(cond_expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
Some(pat) => {
let replacement = let_chains.into_iter().map(|expr| {
if let ast::Expr::LetExpr(let_expr) = &expr
&& let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
{
// If-let.
let let_else_stmt = make::let_else_stmt(
pat,
None,
cond_expr,
ast::make::tail_only_block_expr(early_expression),
expr,
ast::make::tail_only_block_expr(early_expression.clone()),
);
let let_else_stmt = let_else_stmt.indent(if_indent_level);
let_else_stmt.syntax().clone()
} else {
// If.
let new_expr = {
let then_branch = make::block_expr(
once(make::expr_stmt(early_expression.clone()).into()),
None,
);
let cond = invert_boolean_expression_legacy(expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
};
});

let newline = &format!("\n{if_indent_level}");
let then_statements = replacement
.children_with_tokens()
.enumerate()
.flat_map(|(i, node)| {
(i != 0)
.then(|| make::tokens::whitespace(newline).into())
.into_iter()
.chain(node.children_with_tokens())
})
.chain(
then_block_items
.syntax()
Expand Down Expand Up @@ -201,11 +199,7 @@ fn let_stmt_to_guarded_return(
let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
let parent_container = parent_block.syntax().parent()?;

match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
}
early_expression(parent_container, &ctx.sema)?
};

acc.add(
Expand All @@ -232,6 +226,44 @@ fn let_stmt_to_guarded_return(
)
}

fn early_expression(
parent_container: SyntaxNode,
sema: &Semantics<'_, RootDatabase>,
) -> Option<ast::Expr> {
if let Some(fn_) = ast::Fn::cast(parent_container.clone())
&& let Some(fn_def) = sema.to_def(&fn_)
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
{
let none_expr = make::expr_path(make::ext::ident_path("None"));
return Some(make::expr_return(Some(none_expr)));
}
Some(match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
})
}

fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
let mut chains = vec![];

while let ast::Expr::BinExpr(bin_expr) = &expr
&& bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
&& let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
{
if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
} else {
chains.push(rhs);
}
expr = lhs;
}

chains.push(expr);
chains.reverse();
chains
}

#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
Expand Down Expand Up @@ -268,6 +300,37 @@ fn main() {
);
}

#[test]
fn convert_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn ret_option() -> Option<()> {
bar();
if$0 true {
foo();

// comment
bar();
}
}
"#,
r#"
fn ret_option() -> Option<()> {
bar();
if false {
return None;
}
foo();

// comment
bar();
}
"#,
);
}

#[test]
fn convert_let_inside_fn() {
check_assist(
Expand Down Expand Up @@ -316,6 +379,58 @@ fn main() {
);
}

#[test]
fn convert_if_let_chain_result() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if x >= 30 {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);

check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& y < 20
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if !(x < 30 && y < 20) {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);
}

#[test]
fn convert_let_ok_inside_fn() {
check_assist(
Expand Down Expand Up @@ -560,6 +675,32 @@ fn main() {
);
}

#[test]
fn convert_let_stmt_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn foo() -> Option<i32> {
None
}

fn ret_option() -> Option<i32> {
let x$0 = foo();
}
"#,
r#"
fn foo() -> Option<i32> {
None
}

fn ret_option() -> Option<i32> {
let Some(x) = foo() else { return None };
}
"#,
);
}

#[test]
fn convert_let_stmt_inside_loop() {
check_assist(
Expand Down
Loading