Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 252 additions & 45 deletions crates/ide-assists/src/handlers/convert_to_guarded_return.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::iter::once;

use ide_db::{
syntax_helpers::node_ext::{is_pattern_cond, single_let},
ty_filter::TryEnum,
};
use either::Either;
use hir::{Semantics, TypeInfo};
use ide_db::{RootDatabase, ty_filter::TryEnum};
use syntax::{
AstNode,
SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
T,
SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
SyntaxNode, T,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
Expand Down Expand Up @@ -44,12 +43,9 @@ use crate::{
// }
// ```
pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
if let Some(let_stmt) = ctx.find_node_at_offset() {
let_stmt_to_guarded_return(let_stmt, acc, ctx)
} else if let Some(if_expr) = ctx.find_node_at_offset() {
if_expr_to_guarded_return(if_expr, acc, ctx)
} else {
None
match ctx.find_node_at_offset::<Either<ast::LetStmt, ast::IfExpr>>()? {
Either::Left(let_stmt) => let_stmt_to_guarded_return(let_stmt, acc, ctx),
Either::Right(if_expr) => if_expr_to_guarded_return(if_expr, acc, ctx),
}
}

Expand All @@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
return None;
}

// Check if there is an IfLet that we can handle.
let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
let let_ = single_let(cond)?;
(Some(let_.pat()?), let_.expr()?)
} else {
(None, cond)
};
let let_chains = flat_let_chain(cond);

let then_block = if_expr.then_branch()?;
let then_block = then_block.stmt_list()?;
Expand All @@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(

let parent_container = parent_block.syntax().parent()?;

let early_expression: ast::Expr = match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
};
let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;

then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;

Expand All @@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
target,
|edit| {
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
let replacement = match if_let_pat {
None => {
// If.
let new_expr = {
let then_branch =
make::block_expr(once(make::expr_stmt(early_expression).into()), None);
let cond = invert_boolean_expression_legacy(cond_expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
Some(pat) => {
let replacement = let_chains.into_iter().map(|expr| {
if let ast::Expr::LetExpr(let_expr) = &expr
&& let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
{
// If-let.
let let_else_stmt = make::let_else_stmt(
pat,
None,
cond_expr,
ast::make::tail_only_block_expr(early_expression),
expr,
ast::make::tail_only_block_expr(early_expression.clone()),
);
let let_else_stmt = let_else_stmt.indent(if_indent_level);
let_else_stmt.syntax().clone()
} else {
// If.
let new_expr = {
let then_branch = make::block_expr(
once(make::expr_stmt(early_expression.clone()).into()),
None,
);
let cond = invert_boolean_expression_legacy(expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
};
});

let newline = &format!("\n{if_indent_level}");
let then_statements = replacement
.children_with_tokens()
.enumerate()
.flat_map(|(i, node)| {
(i != 0)
.then(|| make::tokens::whitespace(newline).into())
.into_iter()
.chain(node.children_with_tokens())
})
.chain(
then_block_items
.syntax()
Expand Down Expand Up @@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
let parent_container = parent_block.syntax().parent()?;

match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
}
early_expression(parent_container, &ctx.sema)?
};

acc.add(
Expand All @@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
)
}

fn early_expression(
parent_container: SyntaxNode,
sema: &Semantics<'_, RootDatabase>,
) -> Option<ast::Expr> {
let return_none_expr = || {
let none_expr = make::expr_path(make::ext::ident_path("None"));
make::expr_return(Some(none_expr))
};
if let Some(fn_) = ast::Fn::cast(parent_container.clone())
&& let Some(fn_def) = sema.to_def(&fn_)
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
{
return Some(return_none_expr());
}
if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
&& let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
{
return Some(return_none_expr());
}

Some(match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN | CLOSURE_EXPR => make::expr_return(None),
_ => return None,
})
}

fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
let mut chains = vec![];

while let ast::Expr::BinExpr(bin_expr) = &expr
&& bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
&& let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
{
if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
} else {
chains.push(rhs);
}
expr = lhs;
}

chains.push(expr);
chains.reverse();
chains
}

#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
Expand Down Expand Up @@ -268,6 +308,71 @@ fn main() {
);
}

#[test]
fn convert_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn ret_option() -> Option<()> {
bar();
if$0 true {
foo();

// comment
bar();
}
}
"#,
r#"
fn ret_option() -> Option<()> {
bar();
if false {
return None;
}
foo();

// comment
bar();
}
"#,
);
}

#[test]
fn convert_inside_closure() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
let _f = || {
bar();
if$0 true {
foo();

// comment
bar();
}
}
}
"#,
r#"
fn main() {
let _f = || {
bar();
if false {
return;
}
foo();

// comment
bar();
}
}
"#,
);
}

#[test]
fn convert_let_inside_fn() {
check_assist(
Expand Down Expand Up @@ -316,6 +421,82 @@ fn main() {
);
}

#[test]
fn convert_if_let_result_inside_let() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
let _x = loop {
if$0 let Ok(x) = Err(92) {
foo(x);
}
};
}
"#,
r#"
fn main() {
let _x = loop {
let Ok(x) = Err(92) else { continue };
foo(x);
};
}
"#,
);
}

#[test]
fn convert_if_let_chain_result() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if x >= 30 {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);

check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& y < 20
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if !(x < 30 && y < 20) {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);
}

#[test]
fn convert_let_ok_inside_fn() {
check_assist(
Expand Down Expand Up @@ -560,6 +741,32 @@ fn main() {
);
}

#[test]
fn convert_let_stmt_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn foo() -> Option<i32> {
None
}

fn ret_option() -> Option<i32> {
let x$0 = foo();
}
"#,
r#"
fn foo() -> Option<i32> {
None
}

fn ret_option() -> Option<i32> {
let Some(x) = foo() else { return None };
}
"#,
);
}

#[test]
fn convert_let_stmt_inside_loop() {
check_assist(
Expand Down
Loading