🌳 Overloading functions in Python

Aug. 2024

A useful feature that is available in many languages is the idea of overloading functions.

However, even though overloading was proposed for Python back in 2007 in PEP 3124, there still isn't a standard implementation. There is typing.overload, outlined in PEP 484, but it exists only for the type checker and forces you to handle the dispatch manually.
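To make the "manual dispatch" point concrete, here is a minimal sketch of how typing.overload is typically used. The function name double is hypothetical and chosen only for illustration; the decorated stubs inform the type checker, while the final undecorated definition is the only one that runs.

```python
from typing import Union, overload


@overload
def double(value: int) -> int: ...
@overload
def double(value: str) -> str: ...


def double(value: Union[int, str]) -> Union[int, str]:
    # The stubs above are only seen by the type checker. At runtime this
    # single implementation runs and has to branch on the type itself.
    if isinstance(value, int):
        return value * 2
    if isinstance(value, str):
        return value + value
    raise TypeError(f"unsupported type: {type(value).__name__}")
```

Calling double(2) and double("ab") both end up in the same function body, which is exactly the hand-written dispatch the rest of this post tries to avoid.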

When is overloading useful?

If you're reading this, you're probably a Python developer, and you might not be sure why overloading functions is useful. This is my attempt to convince you that function overloading is a neat way to extend what a function can do while keeping the API clean.

Type specialization

Imagine you're working on an application where a user has several unique fields, and there are different ways of retrieving the user based on each one. You might have something like

def get_user_by_id(id: PrimaryKey): ...
def get_user_by_email(email: Email): ...
def get_user_by_phone(phone_number: PhoneNumber): ...
def get_user_by_ssn(ssn: SSN): ...

This works, but complicates the API. Now you have to import each function and think about which one to use whenever you're trying to get a user. A small but significant win would be to let all the functions share one name, and let Python figure out which function to call based on the types of the arguments. We can do this in standard Python by writing our own dispatch function

def get_user(lookup_value: Union[PrimaryKey, Email, PhoneNumber, SSN]):
    match lookup_value:
        case PrimaryKey():
            return get_user_by_id(lookup_value)
        case Email():
            return get_user_by_email(lookup_value)
        case PhoneNumber():
            return get_user_by_phone(lookup_value)
        case SSN():
            return get_user_by_ssn(lookup_value)
        case _:
            raise ValueError("Invalid lookup value")

(Note that this matching works because class patterns such as case Email(): perform an isinstance check; see PEP 634.)

This is fine and solves the problems we discussed, but it forces us to write our own dispatch function every time we want to overload a set of functions. These dispatch functions grow quickly, even with the match-case pattern, and they are one more thing to maintain every time we add another function to overload.

A nicer syntax would be something close to what was suggested in PEP 3124.

@overload
def get_user(id: PrimaryKey): ...

@overload
def get_user(email: Email): ...

@overload
def get_user(phone_number: PhoneNumber): ...

@overload
def get_user(ssn: SSN): ...

This way, we get everything we want and more. We don't have to write our own boilerplate dispatch function, because the decorator handles it for us. We don't clutter our API with multiple function names for what is essentially the same functionality. We don't even have to come up with individual names for the specialized functions, and we keep our functions small.
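For comparison, when the choice depends on a single argument, the standard library's functools.singledispatch comes fairly close to this. It dispatches only on the type of the first argument and uses a register call rather than repeated definitions under one name, so it is not a full replacement for the syntax above. A minimal sketch, assuming hypothetical str-based stubs for Email and PhoneNumber:

```python
from functools import singledispatch


class Email(str): ...          # hypothetical stub
class PhoneNumber(str): ...    # hypothetical stub


@singledispatch
def get_user(lookup_value):
    # Fallback used when no registered type matches.
    raise TypeError(f"unsupported lookup value: {lookup_value!r}")


@get_user.register
def _(email: Email):
    return f"user with email {email}"


@get_user.register
def _(phone_number: PhoneNumber):
    return f"user with phone {phone_number}"
```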

Overloading can also help make ambiguous return types explicit. Imagine a function

def foo(bar: Union[A, B]) -> Union[C, D]:
    match bar:
        case A():
            return C()
        case B():
            return D()

It's clear that foo returns type C only when called with an argument of type A, and returns type D only when called with type B. While that is easy to spot in this simple example, in a larger codebase with a more complicated function it can be hard to work out the return type. Overloading makes it explicit.

@overload
def foo(bar: A) -> C: ...

@overload
def foo(bar: B) -> D: ...

Keyword arguments

Another case where function overloading can be useful is abstracting away control flow for optional keyword arguments.

Consider a video game where you can attack a character, but the character can also block a portion of the attack. You might have a function like

def attack(target: Character,
        power: PositiveInteger,
        blocked: Optional[PositiveInteger] = None):
    match blocked:
        case None:
            target.health -= power
        case PositiveInteger(blocked):
            if power > blocked:
                target.health -= power - blocked

Here we have two independent branches depending on whether the target is able to block part of the attack.

We can abstract this control flow away by using function overloading, arguably making the code easier to read

@overload
def attack(target: Character, power: PositiveInteger):
    target.health -= power

@overload
def attack(target: Character, power: PositiveInteger, blocked: PositiveInteger):
    if power > blocked:
        target.health -= power - blocked

Now we have two imperative functions succinctly defining the logic. And because only one of the two functions is ever called at a time, we reduce the cognitive load: we no longer have to consider all the other ways the call could go, as we had to in the control-flow example.

ℹ️

In this specific case, we could have abstracted the control flow away by letting PositiveInteger include $0$ and passing $0$ as blocked whenever the target is unable to block the attack. I still think overloading is clearer, as we don't have to worry about the blocked parameter in situations where it's irrelevant. Even with a default argument, we would still have to know what blocked does whenever we call the function.

Moreover, this trick only works due to the nature of the specific types and logic we perform on them. Function overloading is much more general, and can be used in situations where tricks like these do not apply.

Overloading is not always the best strategy for dealing with optional keywords. For example, if the control flow is merely setting a default value for a None keyword, default values in keyword parameters are better suited.
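A minimal sketch of that last case, with hypothetical function names: when None only stands in for "use the default", a plain default value says the same thing with no branching and no overloads.

```python
from typing import Optional


def greet_with_none(name: str, greeting: Optional[str] = None) -> str:
    # The only control flow here is replacing None with a default.
    if greeting is None:
        greeting = "Hello"
    return f"{greeting}, {name}!"


def greet(name: str, greeting: str = "Hello") -> str:
    # Same behaviour, expressed directly as a default value.
    return f"{greeting}, {name}!"
```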

But hopefully, I have now convinced you that overloading is a valuable tool to have in the toolbelt, and can make a lot of situations more clear and easier to reason about.

Next we will look at how to implement this syntax.

Implementation

There are several problems we have to solve when implementing the overload decorator. We need a way to keep the earlier functions around when they get overwritten, so that we can later match each function's type annotations against the call arguments and dispatch to the right function.

Collecting the functions

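By default in Python, when we define a function with the same name as another function in the same namespace, the old one gets overwritten. But the decorator is called before the name is rebound, so we can use the globals() dictionary to grab a reference to the old function before it's replaced by the new one. Note that globals() refers to the module in which overload itself is defined, so this approach assumes the decorator and the overloaded functions live in the same module.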

It looks like so

def overload(function) -> Callable:
    # save the old function to a list, so we can avoid overwriting them
    # and access them later during dispatch.
    function.overloads = [function]
    # pointer to the old function before it's overwritten
    old_function = globals().get(function.__name__)

    # if one is true, both should always be true
    # because we are always setting the overloads attribute
    if old_function and hasattr(old_function, "overloads"):
        function.overloads.extend(old_function.overloads)

    # the wraps decorator copies the name and docstring
    # and other special attributes of the original function
    # to the wrapped function
    @functools.wraps(function)
    def wrapper(*args, **kwargs): ...

    # keep the previously overloaded functions on the wrapper we return;
    # wraps already copies function.__dict__, this just makes it explicit.
    wrapper.overloads = function.overloads

    return wrapper

We are creating a list of references to all the overloaded functions in function.overloads. We start with the function being decorated and add any functions collected by previous uses of the decorator, accessed through old_function.overloads. Finally, we save the overloads on the wrapper so that we can access them in the future.
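The ordering this relies on can be checked in isolation. The sketch below uses a hypothetical decorator, record_previous, that only records what was bound to the function's name at decoration time; it assumes the decorator and the decorated functions are defined in the same module.

```python
seen = []


def record_previous(function):
    # Look up whatever is bound to the function's name *before* Python
    # rebinds the name to the value returned by this decorator.
    seen.append(globals().get(function.__name__))
    return function


@record_previous
def greet():
    return "first"


first_greet = greet  # keep a reference to the first definition


@record_previous
def greet():
    return "second"
```

After this runs, seen[0] is None (nothing was bound to greet yet) and seen[1] is the first greet, even though the name greet now refers to the second definition.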

⚠️

Since we are smuggling the previous function references through a monkey-patched attribute on the wrapper, this decorator works best as the decorator closest to the function. A decorator applied on top of it has to carry the overloads attribute along, and since we look the function up by .__name__, it has to copy that attribute too, for example by using @functools.wraps.

Decorators built with @functools.wraps do both by default: besides the names in WRAPPER_ASSIGNMENTS, wraps also merges the wrapped function's __dict__ (WRAPPER_UPDATES), which includes overloads. Decorators that don't use it would have to copy the attribute themselves.
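As a quick check of what @functools.wraps carries over by default: in addition to the names listed in WRAPPER_ASSIGNMENTS, it merges the wrapped function's __dict__ (see WRAPPER_UPDATES), so a custom attribute set before wrapping shows up on the wrapper. A minimal sketch, with hypothetical names plain and tagged:

```python
import functools


def plain(function):
    # An ordinary pass-through decorator that uses functools.wraps.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        return function(*args, **kwargs)
    return wrapper


def tagged():
    return "ok"


tagged.overloads = [tagged]   # custom attribute set before wrapping
wrapped = plain(tagged)       # wrapped.overloads now refers to the same list
```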

Matching the call arguments to the parameter types

Now that we have all the overloaded functions, we need to know which one to call by matching their type annotations with the arguments we get.

We will use a helper function, match(params, args, kwargs, verbose=False) -> bool, that determines whether a function is callable with the given arguments.

The params argument is what we get from inspect.signature(f).parameters. inspect is a standard-library module that is very useful for inspecting Python objects and getting information about them. The attribute we use is an ordered mapping from parameter names to Parameter objects, which carry the type annotations.
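To see what that mapping looks like, here is a short sketch on a hypothetical stand-in function with plain built-in annotations:

```python
import inspect


def attack(target: str, power: int, blocked=0): ...  # hypothetical stand-in


params = inspect.signature(attack).parameters

names = list(params)  # parameter names, in definition order
annotations = {name: p.annotation for name, p in params.items()}

# Parameters without an annotation report the sentinel inspect.Parameter.empty.
blocked_is_unannotated = annotations["blocked"] is inspect.Parameter.empty
```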

match then goes through a three-step process to determine whether the function is callable with the given args and kwargs.

  1. Check that the number of parameters matches the number of arguments. (Note that we don't attempt to handle *args, **kwargs, positional-only or keyword-only parameters, and that parameters with default values still have to be passed explicitly.)
  2. Pair the given arguments with the expected parameters: positional arguments first, then keyword arguments against the remaining parameters.
  3. Check the type of each pairing.
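The first two steps can also be delegated to the standard library. The sketch below is an alternative to the hand-written pairing, not the approach used in this post: it relies on inspect.Signature.bind, which raises TypeError when the arguments do not fit the signature. The helper name pair_arguments is hypothetical. Unlike the count check in step 1, bind accepts calls that omit parameters with default values.

```python
import inspect


def pair_arguments(function, args, kwargs):
    """Return {parameter name: argument}, or None if the call shape does not fit."""
    try:
        bound = inspect.signature(function).bind(*args, **kwargs)
    except TypeError:
        # Too many or too few arguments, or an unknown keyword.
        return None
    return dict(bound.arguments)


def attack(target, power, blocked): ...  # hypothetical stand-in


pairs = pair_arguments(attack, ("goblin",), {"power": 10, "blocked": 3})
no_fit = pair_arguments(attack, ("goblin",), {})
```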

When we check the type, we cannot use isinstance because it doesn't work with generic type annotations. For example:

>>> from typing import List
>>> isinstance([1,2,3], List[int])
TypeError: Subscripted generics cannot be used with class and instance checks

So instead we use trycast's isassignable to check types, which works for both classes and generics. A similar library is typeguard, which can enforce correct types at runtime.
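To give a feel for what such a check involves, here is a deliberately simplified, standard-library-only sketch. It is not how trycast is implemented; the name simple_isassignable is hypothetical, and it only handles plain classes, typing.Union, and List[T].

```python
from typing import Any, List, Union, get_args, get_origin


def simple_isassignable(value: Any, annotation: Any) -> bool:
    origin = get_origin(annotation)
    if origin is None:
        # Plain class such as int or a user-defined type.
        return isinstance(value, annotation)
    if origin is Union:
        return any(simple_isassignable(value, arg) for arg in get_args(annotation))
    if origin is list:
        item_types = get_args(annotation)
        if not isinstance(value, list):
            return False
        if not item_types:
            return True  # bare List, no element type to check
        return all(simple_isassignable(item, item_types[0]) for item in value)
    raise NotImplementedError(f"unsupported annotation: {annotation!r}")
```

A real library covers far more of the typing module (dicts, tuples, literals, protocols and so on), which is why the post leans on one instead of growing a helper like this.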

The implementation of match looks like this:

def match(params, args, kwargs, verbose=False) -> bool:
    if len(params) != len(args) + len(kwargs):
        if verbose:
            print(f"Expected {len(params)} arguments but got {len(args) + len(kwargs)}")
        return False

    # gradually remove parameters that have been matched
    checkable_params = params.copy()

    # first match positional arguments
    for i, (k, v) in enumerate(params.items()):
        if i < len(args):
            # match positional arguments to the first params
            if v.annotation == inspect.Parameter.empty:
                # if the parameter has no annotation, we can't check the type
                # and assume that the type is correct
                checkable_params.pop(k)
                continue
            elif not isassignable(args[i], v.annotation):
                if verbose:
                    print(f"Expected {v.annotation} but got {type(args[i])}")
                return False

            checkable_params.pop(k)
            continue
        # break when we have matched all positional arguments
        break

    # match keyword arguments to the remaining params
    for k, v in kwargs.items():
        matched_param = checkable_params.get(k)
        if matched_param:
            # match keyword arguments to the remaining params
            if matched_param.annotation == inspect.Parameter.empty:
                checkable_params.pop(k)
                continue
            elif not isassignable(v, matched_param.annotation):
                if verbose:
                    print(f"Expected {matched_param.annotation} but got {type(v)}")
                return False

            checkable_params.pop(k)
            continue

        else:
            if verbose:
                print(f"Unexpected keyword argument {k}")
            return False

    if len(checkable_params):
        if verbose:
            print("Still remaining params, ", checkable_params)
        return False
    else:
        # return True only if all parameters have been matched
        return True

Dispatching

Now we have a way of collecting overloaded functions and a way to determine whether a given function is callable, so all that remains is to dispatch the call to the right function and forward the return value.

This can all be done in the wrapper, which we previously left unimplemented. The implementation is straightforward.

@functools.wraps(function)
def wrapper(*args, **kwargs):
    # verbose is assumed to be a flag available from the enclosing scope
    if verbose:
        print("Dispatching to one of ", function.overloads)

    # get the parameters of each overloaded function
    params = [inspect.signature(f).parameters.copy() for f in function.overloads]

    # loop through each function candidate and check if it's callable
    # with the given arguments. Eagerly call and return the first callable function.
    for i, (f, p) in enumerate(zip(function.overloads, params)):
        if verbose:
            print(f"Trying overload ({i}) {f} with params {p}")

        if match(p, args, kwargs, verbose=verbose):
            return f(*args, **kwargs)

    raise TypeError("None of the overloads matched the given arguments")

And that's it! That's the overload decorator fully implemented!
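For readers who want something to paste and run without third-party dependencies, here is a condensed sketch of the same idea. It differs from the code above in three labeled ways: it keeps overloads in a hypothetical module-level registry keyed by module and qualified name instead of using globals(), it pairs arguments with inspect.Signature.bind, and it checks types with plain isinstance, so generic annotations such as List[int] are not supported.

```python
import functools
import inspect
from collections import defaultdict

# Hypothetical registry: (module, qualified name) -> functions in definition order.
_REGISTRY = defaultdict(list)


def _fits(value, annotation):
    # Unannotated parameters accept anything; otherwise use plain isinstance.
    if annotation is inspect.Parameter.empty:
        return True
    return isinstance(value, annotation)


def overload(function):
    key = (function.__module__, function.__qualname__)
    _REGISTRY[key].append(function)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Try candidates in definition order and call the first one that fits.
        for candidate in _REGISTRY[key]:
            signature = inspect.signature(candidate)
            try:
                bound = signature.bind(*args, **kwargs)
            except TypeError:
                continue  # wrong number or names of arguments
            if all(_fits(value, signature.parameters[name].annotation)
                   for name, value in bound.arguments.items()):
                return candidate(*args, **kwargs)
        raise TypeError("None of the overloads matched the given arguments")

    return wrapper


@overload
def area(side: int):
    return side * side


@overload
def area(width: int, height: int):
    return width * height


@overload
def area(name: str):
    return f"no area for {name}"
```

With these definitions, area(3) picks the one-argument integer version, area(2, 5) the two-argument version, and area("circle") the string version, while area(1.5) raises TypeError because no overload accepts a float.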

Now on to some examples.

Examples

If we go back to the get_user example, we can now implement it (using stub classes that only implement a string representation) with exactly the same syntax as before. We see that the appropriate function is called based on the type of the input.

@overload
def get_user(email: Email):
    print("Email:", email)
    return email

@overload
def get_user(phone_number: PhoneNumber):
    print("Phone:", phone_number)
    return phone_number

@overload
def get_user(ssn: SSN):
    print("SSN:", ssn)
    return ssn

get_user(Email("[email protected]")) # prints: Email: [email protected]
get_user(PhoneNumber("123-456-789")) # prints: Phone: 123-456-789
get_user(SSN("123-45-6789")) # prints: SSN: 123-45-6789

We can also implement the video game example. For a suitable PositiveInteger type, we can write

class Character():
    def __init__(self):
        self.health = PositiveInteger(100)
        self.max_health = PositiveInteger(100)

    def __str__(self) -> str:
        return "Character"


@overload
def attack(target: Character, power: PositiveInteger):
    target.health -= power
    print(f"Attacked {target} for {power} damage.")
    print(f"    New health: {target.health}/{target.max_health}.")


@overload
def attack(target: Character, power: PositiveInteger, blocked: PositiveInteger):
    if power > blocked:
        target.health -= power - blocked
        print(f"Attacked {target} for {power - blocked} damage (blocked {blocked}).")
        print(f"    New health: {target.health}/{target.max_health}.")
    else:
        print("Attack was blocked")


target = Character()

attack(target, power=PositiveInteger(10))
# Attacked Character for 10 damage.
#     New health: 90/100.

attack(target, power=PositiveInteger(10), blocked=PositiveInteger(8))
# Attacked Character for 2 damage (blocked 8).
#     New health: 88/100.

attack(target, power=PositiveInteger(10), blocked=PositiveInteger(11))
# Attack was blocked

Code

You can access the code via this gist.
