Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for decorated overloads #15898

Merged
merged 6 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Support decorated overloads
  • Loading branch information
ilevkivskyi committed Aug 17, 2023
commit 70666c95ab19f6cd5eb539aa35b3749fcd98d4b3
91 changes: 59 additions & 32 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,30 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.visit_decorator(defn.items[0])
for fdef in defn.items:
assert isinstance(fdef, Decorator)
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
if defn.is_property:
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
else:
# Perform full check for real overloads to infer type of all decorated
# overload variants.
self.visit_decorator_inner(fdef, allow_empty=True)
if fdef.func.abstract_status in (IS_ABSTRACT, IMPLICITLY_ABSTRACT):
num_abstract += 1
if num_abstract not in (0, len(defn.items)):
self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn)
if defn.impl:
defn.impl.accept(self)
if not defn.is_property:
self.check_overlapping_overloads(defn)
if defn.type is None:
item_types = []
for item in defn.items:
assert isinstance(item, Decorator)
item_type = self.extract_callable_type(item.var.type, item)
if item_type is not None:
item_types.append(item_type)
if item_types:
defn.type = Overloaded(item_types)
# Check override validity after we analyzed current definition.
if defn.info:
found_method_base_classes = self.check_method_override(defn)
if (
Expand All @@ -653,10 +670,35 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.msg.no_overridable_method(defn.name, defn)
self.check_explicit_override_decorator(defn, found_method_base_classes, defn.impl)
self.check_inplace_operator_method(defn)
if not defn.is_property:
self.check_overlapping_overloads(defn)
return None

def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
"""Get type as seen by an overload item caller."""
inner_type = get_proper_type(inner_type)
outer_type: CallableType | None = None
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
outer_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=ctx,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
outer_type = inner_call
if outer_type is None:
self.msg.not_callable(inner_type, ctx)
return outer_type

def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# At this point we should have set the impl already, and all remaining
# items are decorators
Expand All @@ -680,40 +722,20 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

# This can happen if we've got an overload with a different
# decorator or if the implementation is untyped -- we gave up on the types.
inner_type = get_proper_type(inner_type)
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
impl_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=defn.impl,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
impl_type = inner_call
if impl_type is None:
self.msg.not_callable(inner_type, defn.impl)
impl_type = self.extract_callable_type(inner_type, defn.impl)

is_descriptor_get = defn.info and defn.name == "__get__"
for i, item in enumerate(defn.items):
# TODO overloads involving decorators
assert isinstance(item, Decorator)
sig1 = self.function_type(item.func)
assert isinstance(sig1, CallableType)
sig1 = self.extract_callable_type(item.var.type, item)
if sig1 is None:
continue

for j, item2 in enumerate(defn.items[i + 1 :]):
assert isinstance(item2, Decorator)
sig2 = self.function_type(item2.func)
assert isinstance(sig2, CallableType)
sig2 = self.extract_callable_type(item2.var.type, item2)
if sig2 is None:
continue

if not are_argument_counts_overlapping(sig1, sig2):
continue
Expand Down Expand Up @@ -4751,17 +4773,20 @@ def visit_decorator(self, e: Decorator) -> None:
e.var.type = AnyType(TypeOfAny.special_form)
e.var.is_ready = True
return
self.visit_decorator_inner(e)

def visit_decorator_inner(self, e: Decorator, allow_empty: bool = False) -> None:
if self.recurse_into_functions:
with self.tscope.function_scope(e.func):
self.check_func_item(e.func, name=e.func.name)
self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty)

# Process decorators from the inside out to determine decorated signature, which
# may be different from the declared signature.
sig: Type = self.function_type(e.func)
for d in reversed(e.decorators):
if refers_to_fullname(d, OVERLOAD_NAMES):
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
if not allow_empty:
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig, context=e)
Expand All @@ -4788,6 +4813,8 @@ def visit_decorator(self, e: Decorator) -> None:
self.msg.fail("Too many arguments for property", e)
self.check_incompatible_property_override(e)
# For overloaded functions we already checked override for overload as a whole.
if allow_empty:
return
if e.func.info and not e.func.is_dynamic() and not e.is_overload:
found_method_base_classes = self.check_method_override(e)
if (
Expand Down
72 changes: 66 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,13 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
elif isinstance(node, FuncDef):
# Reference to a global function.
result = function_type(node, self.named_type("builtins.function"))
elif isinstance(node, OverloadedFuncDef) and node.type is not None:
# node.type is None when there are multiple definitions of a function
# and it's decorated by something that is not typing.overload
# TODO: use a dummy Overloaded instead of AnyType in this case
# like we do in mypy.types.function_type()?
result = node.type
elif isinstance(node, OverloadedFuncDef):
if node.type is None:
if self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(node.name, e)
result = AnyType(TypeOfAny.from_error)
else:
result = node.type
elif isinstance(node, TypeInfo):
# Reference to a type object.
if node.typeddict_type:
Expand Down Expand Up @@ -1337,6 +1338,56 @@ def transform_callee_type(

return callee

def is_generic_decorator_overload_call(
self, callee_type: ProperType, args: list[Expression]
) -> Overloaded | None:
"""Check if this looks like an application of a generic function to overload argument."""
if not isinstance(callee_type, CallableType) or not callee_type.variables:
return None
if len(callee_type.arg_types) != 1 or len(args) != 1:
# TODO: can we handle more general cases?
return None
if not isinstance(get_proper_type(callee_type.arg_types[0]), CallableType):
return None
if not isinstance(get_proper_type(callee_type.ret_type), CallableType):
return None
with self.chk.local_type_map():
with self.msg.filter_errors():
arg_type = get_proper_type(self.accept(args[0], type_context=None))
if isinstance(arg_type, Overloaded):
return arg_type
return None

def handle_decorator_overload_call(
self, callee_type: CallableType, overloaded: Overloaded, ctx: Context
) -> tuple[Type, Type] | None:
"""Type-check application of a generic callable to an overload.

We check call on each individual overload item, and then combine results into a new
overload. This function should be only used if callee_type takes and returns a Callable.
"""
result = []
inferred_args = []
for item in overloaded.items:
arg = TempNode(typ=item)
with self.msg.filter_errors() as err:
item_result, inferred_arg = self.check_call(callee_type, [arg], [ARG_POS], ctx)
if err.has_new_errors():
# This overload doesn't match.
continue
p_item_result = get_proper_type(item_result)
if not isinstance(p_item_result, CallableType):
continue
p_inferred_arg = get_proper_type(inferred_arg)
if not isinstance(p_inferred_arg, CallableType):
continue
inferred_args.append(p_inferred_arg)
result.append(p_item_result)
if not result or not inferred_args:
# None of the overload matched (or overload was initially malformed).
return None
return Overloaded(result), Overloaded(inferred_args)

def check_call_expr_with_callee_type(
self,
callee_type: Type,
Expand Down Expand Up @@ -1450,6 +1501,15 @@ def check_call(
"""
callee = get_proper_type(callee)

overloaded = self.is_generic_decorator_overload_call(callee, args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: check_call is pretty hot, should we move this inside if isinstance(callee, CallableType): on L1513?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good idea!

if overloaded is not None:
# Special casing for inline application of generic callables to overloads.
# Supporting general case would be tricky, but this should cover 95% of cases.
assert isinstance(callee, CallableType)
overloaded_result = self.handle_decorator_overload_call(callee, overloaded, context)
if overloaded_result is not None:
return overloaded_result

if isinstance(callee, CallableType):
return self.check_callable_call(
callee,
Expand Down
12 changes: 11 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,17 @@ def analyze_instance_member_access(
return analyze_var(name, first_item.var, typ, info, mx)
if mx.is_lvalue:
mx.msg.cant_assign_to_method(mx.context)
signature = function_type(method, mx.named_type("builtins.function"))
if not isinstance(method, OverloadedFuncDef):
signature = function_type(method, mx.named_type("builtins.function"))
else:
if method.type is None:
# Overloads may be not ready if they are decorated. Handle this in same
# manner as we would handle a regular decorated function: defer if possible.
if not mx.no_deferral:
mx.not_ready_callback(method.name, mx.context)
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
Expand Down
4 changes: 3 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,9 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
elif not non_overload_indexes:
self.handle_missing_overload_implementation(defn)

if types:
if types and (not isinstance(defn.impl, Decorator) or not defn.impl.decorators):
# TODO: should we enforce decorated overloads consistency somehow?
# TODO: how do support decorated overloads in stubs without major slow-down?
defn.type = Overloaded(types)
defn.type.line = defn.line

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3062,10 +3062,10 @@ def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]:
reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]"
reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]"
reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`6) -> S`6"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`8) -> S`8"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`9) -> S`9"
reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`16) -> builtins.list[S`16]"
dec4(lambda x: x) # E: Incompatible return value type (got "S", expected "List[object]")
[builtins fixtures/list.pyi]

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-newsemanal.test
Original file line number Diff line number Diff line change
Expand Up @@ -3208,7 +3208,7 @@ class User:

def __init__(self, name: str) -> None:
self.name = name # E: Cannot assign to a method \
# E: Incompatible types in assignment (expression has type "str", variable has type "Callable[..., Any]")
# E: Cannot determine type of "name"

[case testNewAnalyzerMemberNameMatchesTypedDict]
from typing import Union, Any
Expand Down
34 changes: 32 additions & 2 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ f(0)
@overload # E: Name "overload" is not defined
def g(a:int): pass
def g(a): pass # E: Name "g" already defined on line 9
g(0)
g(0) # E: Cannot determine type of "g"

@something # E: Name "something" is not defined
def r(a:int): pass
def r(a): pass # E: Name "r" already defined on line 14
r(0)
r(0) # E: Cannot determine type of "r"
[out]
main:2: error: Name "overload" is not defined
main:4: error: Name "f" already defined on line 2
main:4: error: Name "overload" is not defined
main:6: error: Name "f" already defined on line 2
main:7: error: Cannot determine type of "f"

[case testTypeCheckOverloadWithImplementation]
from typing import overload, Any
Expand Down Expand Up @@ -5226,6 +5227,7 @@ def func(x):
[out]
tmp/lib.pyi:1: error: Name "overload" is not defined
tmp/lib.pyi:4: error: Name "func" already defined on line 1
main:2: error: Cannot determine type of "func"
main:2: note: Revealed type is "Any"

-- Order of errors is different
Expand All @@ -5242,6 +5244,7 @@ def func(x: str) -> str: ...
tmp/lib.pyi:1: error: Name "overload" is not defined
tmp/lib.pyi:3: error: Name "func" already defined on line 1
tmp/lib.pyi:3: error: Name "overload" is not defined
main:3: error: Cannot determine type of "func"
main:3: note: Revealed type is "Any"

[case testLiteralSubtypeOverlap]
Expand Down Expand Up @@ -6613,3 +6616,30 @@ def struct(__cols: Union[List[S], Tuple[S, ...]]) -> int: ...
def struct(*cols: Union[S, Union[List[S], Tuple[S, ...]]]) -> int:
pass
[builtins fixtures/tuple.pyi]

[case testRegularGenericDecoratorOverload]
from typing import Callable, overload, TypeVar, List

S = TypeVar("S")
T = TypeVar("T")
def transform(func: Callable[[S], List[T]]) -> Callable[[S], T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
28 changes: 28 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1646,3 +1646,31 @@ def bar(b: B[P]) -> A[Concatenate[int, P]]:
# N: Got: \
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecDecoratorOverload]
from typing import Callable, overload, TypeVar, List
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")
def transform(func: Callable[P, List[T]]) -> Callable[P, T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
Loading