A useful feature available in many languages is function overloading.
However, even though adding it to Python was proposed as far back as 2007 in PEP 3124, there is still no standard implementation. The typing.overload feature outlined in PEP 484 does exist, but it is only for the type checker and forces you to handle the dispatch manually.
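For comparison, here is roughly what typing.overload gives you today. This is a small sketch, and the double function is hypothetical. The decorated signatures are only seen by static type checkers. The final undecorated definition is the one that actually runs, so it has to dispatch by hand:

```python
from typing import Union, overload


# These stubs only inform static type checkers (e.g. that double(1) is an int).
@overload
def double(value: int) -> int: ...
@overload
def double(value: str) -> str: ...


# At runtime only this definition is called, so it must dispatch manually.
def double(value: Union[int, str]) -> Union[int, str]:
    if isinstance(value, int):
        return value * 2
    return value + value
```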
When is overloading useful?
If you're reading this, you're probably a Python developer, and you might not be sure why function overloading is useful. This is my attempt to convince you that it is a neat way to extend the functionality of functions, resulting in a cleaner API.
Type specialization
Imagine you're working on an application where users have several unique fields, and you have a different way of retrieving a user for each of those fields. You might have something like:
```python
def get_user_by_id(id: PrimaryKey): ...
def get_user_by_email(email: Email): ...
def get_user_by_phone(phone_number: PhoneNumber): ...
def get_user_by_ssn(ssn: SSN): ...
```
This works, but it complicates the API. Now you have to import each function and think about which one to use whenever you want to get a user. A small but significant win would be to give all the functions one name and let Python figure out which one to call based on the argument types. We can do this in standard Python by writing our own dispatch function:
```python
from typing import Union


def get_user(lookup_value: Union[PrimaryKey, Email, PhoneNumber, SSN]):
    match lookup_value:
        case PrimaryKey():
            return get_user_by_id(lookup_value)
        case Email():
            return get_user_by_email(lookup_value)
        case PhoneNumber():
            return get_user_by_phone(lookup_value)
        case SSN():
            return get_user_by_ssn(lookup_value)
        case _:
            raise ValueError("Invalid lookup value")
```
(Note that this case matching works because class patterns use isinstance checks. See PEP 634.)
This is fine and solves the problems we discussed. However, it forces us to write a dispatch function every time we want to overload a set of functions. These dispatch functions grow quickly, even with match-case, and each one is another place to update whenever we add a new overload.
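It's worth noting that the standard library already covers one special case. functools.singledispatch dispatches on the type of the first argument only. Here is a sketch that uses hypothetical str-based stand-ins for Email and PhoneNumber and placeholder return values:

```python
from functools import singledispatch


class Email(str): ...        # hypothetical stand-in lookup type
class PhoneNumber(str): ...  # hypothetical stand-in lookup type


@singledispatch
def get_user(lookup_value):
    # fallback when no registered type matches the first argument
    raise ValueError("Invalid lookup value")


@get_user.register
def _(lookup_value: Email):
    return f"user with email {lookup_value}"


@get_user.register
def _(lookup_value: PhoneNumber):
    return f"user with phone number {lookup_value}"
```

Because only the first argument's type is considered, it doesn't help with overloads that differ in the number of arguments.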
A nicer syntax would be something close to what PEP 3124 suggested:
```python
@overload
def get_user(id: PrimaryKey): ...
@overload
def get_user(email: Email): ...
@overload
def get_user(phone_number: PhoneNumber): ...
@overload
def get_user(ssn: SSN): ...
```
This way, we get everything we want and more. We don't have to write boilerplate dispatch code because the decorator handles it for us. We don't clutter our API with multiple function names for what is essentially the same functionality. We don't even have to come up with individual names for the specialized functions, and we keep each function small.
Overloading can also make ambiguous return types explicit. Consider the function:
```python
from typing import Union


def foo(bar: Union[A, B]) -> Union[C, D]:
    match bar:
        case A():
            return C()
        case B():
            return D()
```
It's clear that foo only returns type C if it is called with an input of type A, and only returns type D if called with an input of type B. That is easy to spot in this simple example. In a larger codebase with a more complicated function, however, the relationship between input and return types can be hard to make out. Overloading makes it explicit:
```python
@overload
def foo(bar: A) -> C: ...
@overload
def foo(bar: B) -> D: ...
```
Keyword arguments
Function overloading can also be useful for abstracting away the control flow around optional keyword arguments.
Consider a video game where you can attack a character, and the character can block a portion of the attack. You might have a function like:
```python
from typing import Optional


def attack(target: Character,
           power: PositiveInteger,
           blocked: Optional[PositiveInteger] = None):
    match blocked:
        case None:
            target.health -= power
        case PositiveInteger(blocked):
            if power > blocked:
                target.health -= power - blocked
```
Here we have two independent branches, depending on whether or not the target can block part of the attack.
We can abstract this control flow away with function overloading, arguably making the code easier to read:
```python
@overload
def attack(target: Character, power: PositiveInteger):
    target.health -= power


@overload
def attack(target: Character, power: PositiveInteger, blocked: PositiveInteger):
    if power > blocked:
        target.health -= power - blocked
```
Now we have two small imperative functions, each succinctly defining one piece of the logic. Only one of them runs for any given call, which reduces the cognitive load. We don't have to consider every other way the function could behave, as we had to in the control-flow version.
In this specific case, we could have abstracted the control flow away by letting PositiveInteger include $0$ and passing $0$ as blocked whenever the target cannot block the attack. I still think overloading is clearer, because we don't have to think about the blocked parameter in situations where it's irrelevant. Even with a default argument, we would still have to know what blocked does whenever we call the function.
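For comparison, the default-argument alternative could look like the sketch below. It assumes plain int values (so 0 is allowed) and a hypothetical minimal Character class:

```python
class Character:
    """Hypothetical minimal character with only a health attribute."""

    def __init__(self, health: int = 100):
        self.health = health


def attack(target: Character, power: int, blocked: int = 0) -> None:
    # "Cannot block" is simply the special case blocked == 0,
    # so a single branch covers both situations.
    if power > blocked:
        target.health -= power - blocked
```

Every caller still sees blocked in the signature, which is the drawback described above.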
Moreover, this trick only works because of the specific types involved and the logic we perform on them. Function overloading is much more general and can be used where tricks like these don't apply.
Overloading is not always the best strategy for dealing with optional keywords, though. For example, if the control flow merely sets a default value when a keyword is None, a default value on the keyword parameter is better suited.
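Here is a hypothetical example of that kind of control flow. A branch that only fills in a default for None can be replaced by a default value in the signature:

```python
# Control flow that only replaces None with a default...
def greet_with_branch(name=None):
    if name is None:
        name = "world"
    return f"Hello, {name}!"


# ...is better expressed as a default value on the parameter.
def greet(name="world"):
    return f"Hello, {name}!"
```

The None-sentinel pattern is still the right choice for mutable defaults such as lists, because a default value is evaluated only once, when the function is defined.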
Hopefully, I have now convinced you that overloading is a valuable tool to have in your toolbelt, and that it can make many situations clearer and easier to reason about.
Next we will look at how to implement this syntax.
Implementation
There are several problems to solve when implementing the overload decorator. We need a way to keep each function around when the next definition with the same name overwrites it. That way, we can later match the functions' type annotations against the call arguments and dispatch to the right one.
Collecting the functions
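By default in Python, when we define a function with the same name as another function in the same namespace, the new function overwrites the old one. However, the decorator is called before the name is rebound. This means we can use the globals() dictionary to grab a reference to the old function before the new one replaces it. Note that globals() returns the namespace of the module where overload itself is defined. This approach therefore assumes the overloaded functions are defined at module level in that same module.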
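To see that the old binding really is still visible while the decorator runs, here is a small standalone experiment. remember_previous is a hypothetical decorator that only records what globals() holds at decoration time:

```python
previous_bindings = []


def remember_previous(function):
    # The module-level name still refers to the previous definition here,
    # because the name is rebound only after the decorator returns.
    previous_bindings.append(globals().get(function.__name__))
    return function


@remember_previous
def greet():
    return "first"


@remember_previous
def greet():  # redefined on purpose
    return "second"
```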
It looks like this:
```python
import functools
from typing import Callable


def overload(function) -> Callable:
    # save the functions to a list, so we don't lose them when they are
    # overwritten and can access them later during dispatch.
    function.overloads = [function]
    # reference to the old function before it's overwritten
    old_function = globals().get(function.__name__)
    # only reuse the list if the old function was itself overloaded
    if old_function and hasattr(old_function, "overloads"):
        function.overloads.extend(old_function.overloads)

    # the wraps decorator copies the name, docstring and other special
    # attributes (and the __dict__) of the original function to the wrapper
    @functools.wraps(function)
    def wrapper(*args, **kwargs): ...

    # keep the previously overloaded functions on the wrapper we return;
    # wraps already copied function.__dict__, but being explicit is clearer
    wrapper.overloads = function.overloads
    return wrapper
```
We collect references to all the overloaded functions in the list function.overloads. We start with the function being decorated and add any functions collected by earlier definitions with the same name, which we access through old_function.overloads. Finally, we save the overloads on the wrapper so that we can access them later.
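Each new definition puts itself first and then appends the earlier ones, so the list ends up newest first. As a result, dispatch tries later definitions before earlier ones. The sketch below shows the order. collect is a hypothetical, stripped-down copy of overload with the dispatch left out:

```python
import functools


def collect(function):
    # same bookkeeping as overload(): the new function first,
    # followed by everything the previous definition collected
    function.overloads = [function]
    old_function = globals().get(function.__name__)
    if old_function is not None and hasattr(old_function, "overloads"):
        function.overloads.extend(old_function.overloads)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        raise NotImplementedError("dispatch is omitted in this sketch")

    wrapper.overloads = function.overloads
    return wrapper


@collect
def pick():
    return "first"


@collect
def pick():
    return "second"


@collect
def pick():
    return "third"


print([f() for f in pick.overloads])
```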
Since we smuggle the previous function references along on a monkey-patched attribute of wrapper, this decorator should be the one closest to the function, because other decorators won't necessarily do the smuggling for us. And since we look up the previous function by its .__name__, any decorator applied on top of overload has to copy that attribute over, for example by using @functools.wraps.
Conveniently, @functools.wraps also updates the wrapper's __dict__ with the wrapped function's __dict__ (this is what functools.WRAPPER_UPDATES controls). Decorators built with it therefore carry the overloads attribute along too, without any need to add overloads to WRAPPER_ASSIGNMENTS.
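A quick check shows that a decorator built with functools.wraps carries the overloads attribute along with __name__, because wraps copies the wrapped function's __dict__. The logged decorator and original function below are hypothetical:

```python
import functools


def logged(function):
    # an ordinary decorator built with functools.wraps
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        print(f"calling {function.__name__}")
        return function(*args, **kwargs)

    return wrapper


def original():
    return 1


original.overloads = [original]  # stand-in for what overload() attaches
decorated = logged(original)
```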
Matching the call arguments to the parameter types
Now that we have all the overloaded functions, we need to know which one to call by matching their type annotations with the arguments we get.
We will use a helper function, match(params, args, kwargs, verbose=False) -> bool, that determines whether a function is callable with the given arguments.
The params argument is what we get from inspect.signature(f).parameters. inspect is a built-in standard library module that is very useful for inspecting Python objects and getting information about them. The Signature object returned by inspect.signature has a parameters attribute. This is an ordered mapping from each parameter's name to a Parameter object, which includes the parameter's type annotation.
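For instance, for a hypothetical function with one annotated parameter and one defaulted parameter, inspect.signature exposes names, annotations and defaults like this:

```python
import inspect


def find_user(email: str, active: bool = True):
    ...


params = inspect.signature(find_user).parameters
for name, param in params.items():
    # each value is an inspect.Parameter; a missing annotation or default
    # is reported as inspect.Parameter.empty
    print(name, param.annotation, param.default, param.kind)
```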
match then goes through a three-step process to determine whether the function is callable with the given args and kwargs:
- Check that the number of parameters matches the number of arguments. (Note that we don't attempt to handle *args, **kwargs, positional-only or keyword-only parameters.)
- Pair the given arguments with the expected parameters: positional arguments first, then keyword arguments with the remaining function parameters.
- Check the type of each pair.
When we check the type, we cannot use isinstance, because it doesn't work with generic type annotations. For example:
```python
>>> from typing import List
>>> isinstance([1, 2, 3], List[int])
TypeError: Subscripted generics cannot be used with class and instance checks
```
So instead, we use isassignable from the trycast library to check types, since it works for both classes and generics. A similar library is typeguard, which can enforce correct types at runtime.
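If you'd rather avoid a dependency, you can build a rough checker for a few common annotation forms on top of typing.get_origin and typing.get_args. This is a hypothetical sketch, not how trycast works. It deliberately raises TypeError for anything it doesn't understand:

```python
import types
from typing import Any, Union, get_args, get_origin

# `X | Y` unions (Python 3.10+) have their own origin type.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def is_assignable(value: object, annotation: object) -> bool:
    """Rough runtime check of `value` against `annotation` (sketch only)."""
    if annotation is Any:
        return True
    if annotation is None or annotation is type(None):
        return value is None
    origin = get_origin(annotation)
    if origin is None:
        # a plain class such as int, str or a user-defined type
        return isinstance(value, annotation)
    args = get_args(annotation)
    if origin in _UNION_ORIGINS:
        return any(is_assignable(value, arg) for arg in args)
    if not isinstance(value, origin):
        return False
    if not args:
        return True  # bare generic such as List or Dict
    if origin in (list, set, frozenset):
        return all(is_assignable(item, args[0]) for item in value)
    if origin is dict:
        key_type, value_type = args
        return all(
            is_assignable(k, key_type) and is_assignable(v, value_type)
            for k, v in value.items()
        )
    if origin is tuple:
        if len(args) == 2 and args[1] is Ellipsis:
            return all(is_assignable(item, args[0]) for item in value)
        return len(value) == len(args) and all(
            is_assignable(item, arg) for item, arg in zip(value, args)
        )
    raise TypeError(f"unsupported annotation: {annotation!r}")
```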
The implementation of match looks like this:
```python
import inspect

from trycast import isassignable


def match(params, args, kwargs, verbose=False) -> bool:
    if len(params) != len(args) + len(kwargs):
        if verbose:
            print(f"Expected {len(params)} arguments but got {len(args) + len(kwargs)}")
        return False
    # gradually remove parameters that have been matched
    checkable_params = params.copy()
    # first match positional arguments
    for i, (k, v) in enumerate(params.items()):
        if i < len(args):
            # match positional arguments to the first params
            if v.annotation is inspect.Parameter.empty:
                # if the parameter has no annotation, we can't check the type
                # and assume that the type is correct
                checkable_params.pop(k)
                continue
            elif not isassignable(args[i], v.annotation):
                if verbose:
                    print(f"Expected {v.annotation} but got {type(args[i])}")
                return False
            checkable_params.pop(k)
            continue
        # break when we have matched all positional arguments
        break
    # match keyword arguments to the remaining params
    for k, v in kwargs.items():
        matched_param = checkable_params.get(k)
        if matched_param is not None:
            # unannotated parameters accept any value
            if matched_param.annotation is inspect.Parameter.empty:
                checkable_params.pop(k)
                continue
            elif not isassignable(v, matched_param.annotation):
                if verbose:
                    print(f"Expected {matched_param.annotation} but got {type(v)}")
                return False
            checkable_params.pop(k)
            continue
        else:
            if verbose:
                print(f"Unexpected keyword argument {k}")
            return False
    if len(checkable_params):
        if verbose:
            print("Still remaining params, ", checkable_params)
        return False
    else:
        # return True only if all parameters have been matched
        return True
```
Dispatching
Now we have a way of collecting overloaded functions and a way to determine whether a given function is callable. All that's left is to dispatch the call to the right function and forward its return value.
This can all be done in the wrapper, which we previously left unimplemented. The implementation is straightforward:
```python
# inside overload(), replacing the placeholder wrapper;
# assumes `verbose` is a flag defined in the enclosing scope
@functools.wraps(function)
def wrapper(*args, **kwargs):
    if verbose:
        print("Dispatching to one of ", function.overloads)
    # get the parameters of each overloaded function
    params = [inspect.signature(f).parameters.copy() for f in function.overloads]
    # loop through each candidate and check whether it's callable with the
    # given arguments. Eagerly call and return the first match.
    for i, (f, p) in enumerate(zip(function.overloads, params)):
        if verbose:
            print(f"Trying overload ({i}) {f} with params {p}")
        if match(p, args, kwargs, verbose=verbose):
            return f(*args, **kwargs)
    raise TypeError("None of the overloads matched the given arguments")
```
And that's it! The overload decorator is now fully implemented.
Now on to some examples.
Examples
If we go back to the get_user example, we can now implement it using exactly the same syntax as before, with stub classes that only implement a string representation. The appropriate function is called based on the type of the input:
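```python
# hypothetical stand-ins for the stub classes: str subclasses whose
# string representation is just the wrapped value
class Email(str): ...
class PhoneNumber(str): ...
class SSN(str): ...


@overload
def get_user(email: Email):
    print("Email:", email)
    return email


@overload
def get_user(phone_number: PhoneNumber):
    print("Phone:", phone_number)
    return phone_number


@overload
def get_user(ssn: SSN):
    print("SSN:", ssn)
    return ssn


get_user(Email("user@example.com"))  # placeholder address; prints: Email: user@example.com
get_user(PhoneNumber("123-456-789"))  # prints: Phone: 123-456-789
get_user(SSN("123-45-6789"))  # prints: SSN: 123-45-6789
```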
We can also implement the video game example. Given a suitable PositiveInteger type, we can write:
```python
class PositiveInteger(int):
    # hypothetical stand-in for "a suitable PositiveInteger type"
    def __new__(cls, value):
        if value <= 0:
            raise ValueError(f"{value} is not positive")
        return super().__new__(cls, value)


class Character:
    def __init__(self):
        self.health = PositiveInteger(100)
        self.max_health = PositiveInteger(100)

    def __str__(self) -> str:
        return "Character"


@overload
def attack(target: Character, power: PositiveInteger):
    target.health -= power
    print(f"Attacked {target} for {power} damage.")
    print(f" New health: {target.health}/{target.max_health}.")


@overload
def attack(target: Character, power: PositiveInteger, blocked: PositiveInteger):
    if power > blocked:
        target.health -= power - blocked
        print(f"Attacked {target} for {power - blocked} damage (blocked {blocked}).")
        print(f" New health: {target.health}/{target.max_health}.")
    else:
        print("Attack was blocked")


target = Character()
attack(target, power=PositiveInteger(10))
# Attacked Character for 10 damage.
# New health: 90/100.
attack(target, power=PositiveInteger(10), blocked=PositiveInteger(8))
# Attacked Character for 2 damage (blocked 8).
# New health: 88/100.
attack(target, power=PositiveInteger(10), blocked=PositiveInteger(11))
# Attack was blocked
```
Code
You can access the code via this gist.