Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/pyrefly_types/src/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,15 @@ impl Param {
}
}

pub fn name(&self) -> Option<&Name> {
match self {
Param::PosOnly(name, ..) | Param::VarArg(name, ..) | Param::Kwargs(name, ..) => {
name.as_ref()
}
Param::Pos(name, ..) | Param::KwOnly(name, ..) => Some(name),
}
}

pub fn as_type(&self) -> &Type {
match self {
Param::PosOnly(_, ty, _)
Expand Down
152 changes: 125 additions & 27 deletions pyrefly/lib/state/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,17 @@ pub struct FindDefinitionItem {
pub module: Module,
}

/// The currently active argument in a function call for signature help.
#[derive(Debug)]
enum ActiveArgument {
/// The cursor is within an existing positional argument at the given index.
Positional(usize),
/// The cursor is within a keyword argument whose name is provided.
Keyword(Name),
/// The cursor is in the argument list but not inside any argument expression yet.
Next(usize),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you add a doc-comment to these? it's unclear what Next is

}

impl<'a> Transaction<'a> {
fn get_type(&self, handle: &Handle, key: &Key) -> Option<Type> {
let idx = self.get_bindings(handle)?.key_to_idx(key);
Expand Down Expand Up @@ -844,42 +855,93 @@ impl<'a> Transaction<'a> {
fn visit_finding_signature_range(
x: &Expr,
find: TextSize,
res: &mut Option<(TextRange, TextRange, usize)>,
res: &mut Option<(TextRange, TextRange, ActiveArgument)>,
) {
if let Expr::Call(call) = x
&& call.arguments.range.contains_inclusive(find)
{
for (i, arg) in call.arguments.args.as_ref().iter().enumerate() {
if arg.range().contains_inclusive(find) {
Self::visit_finding_signature_range(arg, find, res);
if res.is_some() {
return;
}
*res = Some((call.func.range(), call.arguments.range, i));
}
if Self::visit_positional_signature_args(call, find, res) {
return;
}
if Self::visit_keyword_signature_args(call, find, res) {
return;
}
if res.is_none() {
*res = Some((
call.func.range(),
call.arguments.range,
call.arguments.len(),
ActiveArgument::Next(call.arguments.len()),
));
}
} else {
x.recurse(&mut |x| Self::visit_finding_signature_range(x, find, res));
}
}

fn visit_positional_signature_args(
call: &ExprCall,
find: TextSize,
res: &mut Option<(TextRange, TextRange, ActiveArgument)>,
) -> bool {
for (i, arg) in call.arguments.args.as_ref().iter().enumerate() {
if arg.range().contains_inclusive(find) {
Self::visit_finding_signature_range(arg, find, res);
if res.is_some() {
return true;
}
*res = Some((
call.func.range(),
call.arguments.range,
ActiveArgument::Positional(i),
));
return true;
}
}
false
}

fn visit_keyword_signature_args(
call: &ExprCall,
find: TextSize,
res: &mut Option<(TextRange, TextRange, ActiveArgument)>,
) -> bool {
let kwarg_start_idx = call.arguments.args.len();
for (j, kw) in call.arguments.keywords.iter().enumerate() {
if kw.range.contains_inclusive(find) {
Self::visit_finding_signature_range(&kw.value, find, res);
if res.is_some() {
return true;
}
let active_argument = match kw.arg.as_ref() {
Some(identifier) => ActiveArgument::Keyword(identifier.id.clone()),
None => ActiveArgument::Positional(kwarg_start_idx + j),
};
*res = Some((call.func.range(), call.arguments.range, active_argument));
return true;
}
}
false
}

/// Finds the callable(s) (multiple if overloads exist) at position in document, returning them, chosen overload index, and arg index
fn get_callables_from_call(
&self,
handle: &Handle,
position: TextSize,
) -> Option<(Vec<Type>, usize, usize)> {
) -> Option<(Vec<Type>, usize, ActiveArgument)> {
let mod_module = self.get_ast(handle)?;
let mut res = None;
mod_module.visit(&mut |x| Self::visit_finding_signature_range(x, position, &mut res));
let (callee_range, call_args_range, arg_index) = res?;
let (callee_range, call_args_range, mut active_argument) = res?;
// When the cursor is in the argument list but not inside any argument yet,
// estimate the would-be positional index by counting commas up to the cursor.
// This keeps signature help useful even before the user starts typing the next arg.
if let ActiveArgument::Next(index) = &mut active_argument
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a comment here for what this part does and why it's necessary

&& let Some(next_index) =
self.count_argument_separators_before(handle, call_args_range, position)
{
*index = next_index;
}
let answers = self.get_answers(handle)?;
if let Some((overloads, chosen_overload_index)) =
answers.get_all_overload_trace(call_args_range)
Expand All @@ -888,12 +950,12 @@ impl<'a> Transaction<'a> {
Some((
callables,
chosen_overload_index.unwrap_or_default(),
arg_index,
active_argument,
))
} else {
answers
.get_type_trace(callee_range)
.map(|t| (vec![t], 0, arg_index))
.map(|t| (vec![t], 0, active_argument))
}
}

Expand All @@ -903,27 +965,33 @@ impl<'a> Transaction<'a> {
position: TextSize,
) -> Option<SignatureHelp> {
self.get_callables_from_call(handle, position).map(
|(callables, chosen_overload_index, arg_index)| SignatureHelp {
signatures: callables
|(callables, chosen_overload_index, active_argument)| {
let signatures = callables
.into_iter()
.map(|t| Self::create_signature_information(t, arg_index))
.collect_vec(),
active_signature: Some(chosen_overload_index as u32),
active_parameter: Some(arg_index as u32),
.map(|t| Self::create_signature_information(t, &active_argument))
.collect_vec();
let active_parameter = signatures
.get(chosen_overload_index)
.and_then(|info| info.active_parameter);
SignatureHelp {
signatures,
active_signature: Some(chosen_overload_index as u32),
active_parameter,
}
},
)
}

fn create_signature_information(type_: Type, arg_index: usize) -> SignatureInformation {
fn create_signature_information(
type_: Type,
active_argument: &ActiveArgument,
) -> SignatureInformation {
let type_ = type_.deterministic_printing();
let label = type_.as_hover_string();
let (parameters, active_parameter) =
if let Some(params) = Self::normalize_singleton_function_type_into_params(type_) {
let active_parameter = if arg_index < params.len() {
Some(arg_index as u32)
} else {
None
};
let active_parameter =
Self::active_parameter_index(&params, active_argument).map(|idx| idx as u32);
(
Some(params.map(|param| ParameterInformation {
label: ParameterLabel::Simple(format!("{param}")),
Expand All @@ -942,6 +1010,35 @@ impl<'a> Transaction<'a> {
}
}

fn active_parameter_index(params: &[Param], active_argument: &ActiveArgument) -> Option<usize> {
match active_argument {
ActiveArgument::Positional(index) | ActiveArgument::Next(index) => {
(*index < params.len()).then_some(*index)
}
ActiveArgument::Keyword(name) => params
.iter()
.position(|param| param.name().is_some_and(|param_name| param_name == name)),
}
}

fn count_argument_separators_before(
&self,
handle: &Handle,
arguments_range: TextRange,
position: TextSize,
) -> Option<usize> {
let module = self.get_module_info(handle)?;
let contents = module.contents();
let len = contents.len();
let start = arguments_range.start().to_usize().min(len);
let end = arguments_range.end().to_usize().min(len);
let pos = position.to_usize().clamp(start, end);
contents
.get(start..pos)
.map(|slice| slice.bytes().filter(|&b| b == b',').count())
.or(Some(0))
}

fn normalize_singleton_function_type_into_params(type_: Type) -> Option<Vec<Param>> {
let callable = type_.to_callable()?;
// We will drop the self parameter for signature help
Expand Down Expand Up @@ -2169,11 +2266,12 @@ impl<'a> Transaction<'a> {
position: TextSize,
completions: &mut Vec<CompletionItem>,
) {
if let Some((callables, chosen_overload_index, arg_index)) =
if let Some((callables, chosen_overload_index, active_argument)) =
self.get_callables_from_call(handle, position)
&& let Some(callable) = callables.get(chosen_overload_index)
&& let Some(params) =
Self::normalize_singleton_function_type_into_params(callable.clone())
&& let Some(arg_index) = Self::active_parameter_index(&params, &active_argument)
&& let Some(param) = params.get(arg_index)
{
Self::add_literal_completions_from_type(param.as_type(), completions);
Expand Down
93 changes: 93 additions & 0 deletions pyrefly/lib/test/lsp/signature_help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,99 @@ Signature Help Result: active=0
);
}

#[test]
fn positional_arguments_test() {
let code = r#"
def f(x: int, y: int, z: int) -> None: ...

f(1,,)
# ^
f(1,,)
# ^
f(1,,3)
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert_eq!(
r#"
# main.py
4 | f(1,,)
^
Signature Help Result: active=0
- def f(
x: int,
y: int,
z: int
) -> None: ..., parameters=[x: int, y: int, z: int], active parameter = 1

6 | f(1,,)
^
Signature Help Result: active=0
- def f(
x: int,
y: int,
z: int
) -> None: ..., parameters=[x: int, y: int, z: int], active parameter = 2

8 | f(1,,3)
^
Signature Help Result: active=0
- def f(
x: int,
y: int,
z: int
) -> None: ..., parameters=[x: int, y: int, z: int], active parameter = 1
"#
.trim(),
report.trim(),
);
}

#[test]
fn keyword_arguments_test() {
let code = r#"
def f(a: str, b: int) -> None: ...

f(a)
# ^
f(a=)
# ^
f(b=)
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert_eq!(
r#"
# main.py
4 | f(a)
^
Signature Help Result: active=0
- def f(
a: str,
b: int
) -> None: ..., parameters=[a: str, b: int], active parameter = 0

6 | f(a=)
^
Signature Help Result: active=0
- def f(
a: str,
b: int
) -> None: ..., parameters=[a: str, b: int], active parameter = 0

8 | f(b=)
^
Signature Help Result: active=0
- def f(
a: str,
b: int
) -> None: ..., parameters=[a: str, b: int], active parameter = 1
"#
.trim(),
report.trim(),
);
}

#[test]
fn simple_incomplete_function_call_test() {
let code = r#"
Expand Down