Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 151 additions & 1 deletion pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use pyrefly_types::types::NNModuleType;
use pyrefly_types::types::TArgs;
use pyrefly_types::types::TParams;
use pyrefly_types::types::Union;
use pyrefly_util::display::count;
use pyrefly_util::prelude::SliceExt;
use pyrefly_util::prelude::VecExt;
use ruff_python_ast::Arguments;
Expand Down Expand Up @@ -53,8 +54,10 @@ use crate::types::callable::Callable;
use crate::types::callable::FuncMetadata;
use crate::types::callable::Function;
use crate::types::callable::FunctionKind;
use crate::types::callable::Param;
use crate::types::callable::ParamList;
use crate::types::callable::Params;
use crate::types::callable::Required;
use crate::types::class::ClassType;
use crate::types::keywords::KwCall;
use crate::types::keywords::TypeMap;
Expand Down Expand Up @@ -980,14 +983,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
.solver()
.finish_quantified(vs, self.solver().infer_with_first_use)
.err();
let result = if let Some(mut ret) = dunder_new_ret {
let partial_callable = self.functools_partial_callable(&cls, args, keywords, &errors);
let mut result = if let Some(mut ret) = dunder_new_ret {
ret.subst_mut(&cls.targs().substitution_map());
ret
} else if constructor_kind == ConstructorKind::TypeOfSelf {
self.heap.mk_self_type(cls)
} else {
self.heap.mk_class_type(cls)
};
if let Some(partial_callable) = partial_callable {
let partial_instance = result.clone();
result = self.heap.mk_intersect(
vec![partial_instance, partial_callable.clone()],
partial_callable,
);
}
// Normalize builtins.tuple instances to structural Type::Tuple so downstream
// match arms (concat, unpacking, except, etc.) handle them directly.
if let Type::ClassType(ref ct) = result
Expand Down Expand Up @@ -1026,6 +1037,145 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn functools_partial_callable(
&self,
cls: &ClassType,
args: &[CallArg],
keywords: &[CallKeyword],
errors: &ErrorCollector,
) -> Option<Type> {
if !cls.has_qname("functools", "partial") {
return None;
}
let (CallArg::Arg(func), bound_args) = args.split_first()? else {
return None;
};
let callable = self.callable_signature_for_partial_target(func.infer(self, errors))?;
Some(
self.heap.mk_callable_from(
self.bind_partial_callable(&callable, bound_args, keywords, errors)?,
),
)
}

fn callable_signature_for_partial_target(&self, ty: Type) -> Option<Callable> {
match self.as_call_target(ty) {
CallTargetLookup::Ok(box CallTarget::Callable(TargetWithTParams(None, callable))) => {
Some(callable)
}
CallTargetLookup::Ok(box CallTarget::Function(TargetWithTParams(None, function))) => {
Some(function.signature)
}
CallTargetLookup::Ok(box CallTarget::BoundMethod(
_,
TargetWithTParams(None, function),
)) => Some(function.signature.strip_self_param()),
_ => None,
}
}

fn bind_partial_callable(
&self,
callable: &Callable,
bound_args: &[CallArg],
keywords: &[CallKeyword],
errors: &ErrorCollector,
) -> Option<Callable> {
let Params::List(params) = &callable.params else {
return None;
};
let mut remaining = params.items().to_vec();
let total_bound_positional = bound_args
.iter()
.filter(|arg| matches!(arg, CallArg::Arg(_)))
.count();
for arg in bound_args {
let CallArg::Arg(arg) = arg else {
return None;
};
let matched = if let Some(idx) = remaining.iter().position(|param| {
matches!(
param,
Param::PosOnly(_, _, _) | Param::Pos(_, _, _) | Param::Varargs(_, _)
)
}) {
if !matches!(remaining[idx], Param::Varargs(_, _)) {
remaining.remove(idx);
}
true
} else {
false
};
if !matched {
self.error(
errors,
arg.range(),
ErrorInfo::Kind(ErrorKind::BadArgumentCount),
format!(
"Expected {}, got {}",
count(
callable.arg_counts().positional.max.expect(
"partial only reports extra positional arguments for non-variadic callables",
),
"positional argument",
),
total_bound_positional,
),
);
return None;
}
}
for kw in keywords {
let name = kw.arg.map(|id| &id.id)?;
let mut matched = false;
for idx in 0..remaining.len() {
match &remaining[idx] {
Param::Pos(param_name, _, _) | Param::KwOnly(param_name, _, _)
if param_name == name =>
{
if matches!(remaining[idx], Param::Pos(_, _, _)) {
for later in remaining.iter_mut().skip(idx + 1) {
if let Param::Pos(later_name, later_ty, later_required) = later {
*later = Param::KwOnly(
later_name.clone(),
later_ty.clone(),
later_required.clone(),
);
}
}
}
remaining[idx] = match &remaining[idx] {
Param::Pos(param_name, ty, _) | Param::KwOnly(param_name, ty, _) => {
Param::KwOnly(
param_name.clone(),
ty.clone(),
Required::Optional(None),
)
}
_ => unreachable!(
"matched partial keyword must be positional or keyword-only"
),
};
matched = true;
break;
}
Param::Kwargs(_, _) => {
matched = true;
break;
}
_ => {}
}
}
if !matched {
return None;
}
}
Some(Callable::list(
ParamList::new(remaining),
callable.ret.clone(),
))
}

/// If the class has a registered init capture, extract constructor arg values
/// and wrap the result in `Type::NNModule`. Otherwise return the result as-is.
///
Expand Down
60 changes: 60 additions & 0 deletions pyrefly/lib/test/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,66 @@ zoo(partial(bar, b=99))
"#,
);

testcase!(
test_functools_partial_preserves_remaining_signature,
r#"
from functools import partial

def f(a: int, b: str) -> bool:
return True

g = partial(f, 1)
g("foo", "bar") # E: Expected 1 positional argument, got 2
g(1) # E: Argument `Literal[1]` is not assignable to parameter `b` with type `str`
g("foo")
"#,
);

testcase!(
test_functools_partial_rejects_too_many_bound_args,
r#"
from functools import partial

def f(a: int, b: str, c: int, d: str) -> tuple[int, str]:
return (a + c, b + d)

partial(f, 1, "a", 2, "b", 3, "c", 4, "d") # E: Expected 4 positional arguments, got 8
"#,
);

testcase!(
test_functools_partial_preserves_partial_object_type,
r#"
from collections.abc import Callable
from functools import partial
from typing import Any, assert_type

def f(a: int, b: str) -> bool:
return True

g: partial[bool] = partial(f, 1)
assert_type(g.args, tuple[Any, ...])
assert_type(g.keywords, dict[str, Any])
assert_type(g.func, Callable[..., bool])
g("foo")
"#,
);

testcase!(
test_functools_partial_bound_keyword_remains_overrideable,
r#"
from functools import partial

def f(a: int, b: str) -> bool:
return True

g = partial(f, b="x")
g(1)
g(1, b="y")
g(1, "y") # E: Expected 1 positional argument, got 2
"#,
);

testcase!(
bug = "Self in Metaclass should be treated as Any. Any in metaclass call should act like no annot.",
test_callable_class_substitute_self,
Expand Down
Loading