diff --git a/pyrefly/lib/lsp/wasm/completion.rs b/pyrefly/lib/lsp/wasm/completion.rs index 25a725a61d..6fae1d18e5 100644 --- a/pyrefly/lib/lsp/wasm/completion.rs +++ b/pyrefly/lib/lsp/wasm/completion.rs @@ -30,6 +30,7 @@ use ruff_python_ast::AnyNodeRef; use ruff_python_ast::ExprContext; use ruff_python_ast::Identifier; use ruff_python_ast::ModModule; +use ruff_python_ast::Pattern; use ruff_python_ast::name::Name; use ruff_text_size::Ranged; use ruff_text_size::TextRange; @@ -72,6 +73,7 @@ pub(crate) struct RankedCompletion { pub(crate) item: CompletionItem, pub(crate) source: CompletionSource, pub(crate) is_incompatible: bool, + pub(crate) is_deprioritized: bool, } impl RankedCompletion { @@ -81,6 +83,7 @@ impl RankedCompletion { item, source: CompletionSource::Local, is_incompatible: false, + is_deprioritized: false, } } } @@ -119,7 +122,7 @@ fn assign_sort_text(ranked: &mut RankedCompletion, mru_rank: Option { }, source: autoimport_source(module_name_str), is_incompatible: false, + is_deprioritized: false, }); Some(module_name) } @@ -594,6 +598,7 @@ impl Transaction<'_> { }, source: CompletionSource::Local, is_incompatible, + is_deprioritized: false, }) } } @@ -721,6 +726,7 @@ impl Transaction<'_> { }, source: autoimport_source(&imported_module), is_incompatible: false, + is_deprioritized: false, }); } @@ -758,6 +764,7 @@ impl Transaction<'_> { }, source, is_incompatible: false, + is_deprioritized: false, }); } if let Some(module_handle) = self.import_handle(handle, module_name, None).finding() @@ -790,6 +797,7 @@ impl Transaction<'_> { }, source, is_incompatible: false, + is_deprioritized: false, }); } } @@ -888,10 +896,98 @@ impl Transaction<'_> { }, source, is_incompatible, + is_deprioritized: false, }); }); }); } + /// Demote enum members already covered by earlier `case` arms in the same `match`. + fn deprioritize_previously_matched_enum_members( + &self, + handle: &Handle, + position: TextSize, + enum_type: &Type, + completions: &mut [RankedCompletion], + ) { + let Some(ast) = self.get_ast(handle) else { + return; + }; + let nodes = Ast::locate_node(ast.as_ref(), position); + if !nodes + .iter() + .any(|node| matches!(node, AnyNodeRef::PatternMatchValue(_))) + { + return; + } + let Some(stmt_match) = nodes.iter().find_map(|node| match node { + AnyNodeRef::StmtMatch(stmt_match) => Some(stmt_match), + _ => None, + }) else { + return; + }; + let Some(current_case_idx) = stmt_match + .cases + .iter() + .position(|case| case.range.contains_inclusive(position)) + .or_else(|| { + stmt_match + .cases + .iter() + .rposition(|case| case.range.start() <= position) + }) + else { + return; + }; + + let mut matched_members = SmallSet::new(); + for case in stmt_match.cases.iter().take(current_case_idx) { + self.collect_matched_enum_members( + handle, + enum_type, + &case.pattern, + &mut matched_members, + ); + } + if matched_members.is_empty() { + return; + } + for completion in completions { + if matched_members.contains(&completion.item.label) { + completion.is_deprioritized = true; + } + } + } + + fn collect_matched_enum_members( + &self, + handle: &Handle, + enum_type: &Type, + pattern: &Pattern, + matched_members: &mut SmallSet, + ) { + match pattern { + Pattern::MatchValue(pattern) => { + if let Some(value_type) = self.get_type_trace(handle, pattern.value.range()) + && value_type.qname() == enum_type.qname() + && let Type::Literal(lit) = value_type + && let Lit::Enum(lit_enum) = lit.value + { + matched_members.insert(lit_enum.member.as_str().to_owned()); + } + } + Pattern::MatchAs(pattern) => { + if let Some(pattern) = pattern.pattern.as_deref() { + self.collect_matched_enum_members(handle, enum_type, pattern, matched_members); + } + } + Pattern::MatchOr(pattern) => { + for pattern in &pattern.patterns { + self.collect_matched_enum_members(handle, enum_type, pattern, matched_members); + } + } + _ => {} + } + } /// Core completion implementation returning items and incomplete flag. pub(crate) fn completion_sorted_opt_with_incomplete( @@ -1043,10 +1139,16 @@ impl Transaction<'_> { { self.add_attribute_completions_for_type( handle, - base_type, + base_type.clone(), expected_type.as_ref(), &mut result, ); + self.deprioritize_previously_matched_enum_members( + handle, + position, + &base_type, + &mut result, + ); } } Some(IdentifierWithContext { diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index 8ab41d3031..1146111df1 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -1461,6 +1461,51 @@ Completion Results: ); } +#[test] +fn completion_demotes_previously_matched_enum_members() { + let code = r#" +from enum import StrEnum, auto + +class A(StrEnum): + AA = auto() + BB = auto() + +def f(a: A): + match a: + case A.AA: + ... + case A. +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error( + &[("main", code)], + |state, handle, position| { + let mut report = String::new(); + for item in state + .transaction() + .completion(handle, position, ImportFormat::Absolute, true, None) + .into_iter() + .filter(|item| matches!(item.label.as_str(), "AA" | "BB")) + { + report.push_str(&item.label); + report.push('\n'); + } + report + }, + ); + + let bb_index = report.find("BB\n"); + let aa_index = report.find("AA\n"); + assert!( + bb_index.is_some() && aa_index.is_some(), + "Expected completions for AA and BB." + ); + assert!( + bb_index.unwrap() < aa_index.unwrap(), + "Expected the unmatched enum member to sort first." + ); +} + #[test] fn completion_literal_union_alias() { let code = r#"