A while ago, in 2022, I wrote an article about why currying can be useful and how to implement it in Python. The parts about why currying is useful still apply. However, the code was very complicated and used a lot of hacks to keep track of variables. As a result, there were a lot of bugs.
A severe bug was that the decorated function could only be used once. Since we were changing the bound parameters directly on the function, after the first call it stayed bound to those specific values, with no way of changing them.
This is not very useful.
Meanwhile, I have also learned a lot more about function introspection, so I decided to reimplement the function from scratch, making use of function signatures and `bind` to manage arguments.
The idea is to use `inspect.signature` to get the function signature, i.e. which `args` and `kwargs` the function expects, and then use `signature.bind_partial` to match the parameters. `bind_partial` is much nicer to work with than the `Parameter` inspection we did before. It allows you to match values to parameters by just passing them as you would to a function. The result can then be amended with `bind.apply_defaults()`, which fills in any missing values with the default values from the signature.
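As a rough illustration of these steps, here is a small sketch that uses only the standard `inspect` module. The `greet` function and its parameters are made up for this example.

```python
import inspect


# Hypothetical function, used only for illustration
def greet(name, greeting="hello", punctuation="!"):
    return f"{greeting}, {name}{punctuation}"


sig = inspect.signature(greet)
print(list(sig.parameters))  # ['name', 'greeting', 'punctuation']

# bind_partial accepts an incomplete set of arguments without raising
bound = sig.bind_partial("world")
print(dict(bound.arguments))  # {'name': 'world'}

# apply_defaults fills in the parameters that have default values
bound.apply_defaults()
print(dict(bound.arguments))
# {'name': 'world', 'greeting': 'hello', 'punctuation': '!'}
```

Note that `bind_partial` still raises a `TypeError` for arguments that cannot match the signature at all, such as an unknown keyword.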
Now we can compare the number of arguments in `bind` to the total number of parameters expected. This gives us a robust way of determining whether we can compute the final result or need to curry.
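This check could be sketched as a small standalone helper. The name `needs_currying` and the `volume` function are hypothetical; they only isolate the comparison for illustration.

```python
import inspect


def needs_currying(f, *args, **kwargs):
    """Return True if the given arguments do not yet cover every parameter of f."""
    sig = inspect.signature(f)
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()
    return len(bound.arguments) < len(sig.parameters)


# Hypothetical function with one default value
def volume(width, height, depth=1):
    return width * height * depth


print(needs_currying(volume, 2))     # True: height is still missing
print(needs_currying(volume, 2, 3))  # False: depth falls back to its default
```

Because defaults are applied before counting, parameters with default values never force an extra call.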
If we do need to curry, we have to split the arguments up into positional and keyword arguments, so that we can apply subsequent positional arguments. For example:

    @partial
    def my_sum(a, b, c):
        return a + b + c

    my_sum(1)(2)(3)  # should output 6
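One way to see the split between positional and keyword arguments is through the `args` and `kwargs` properties of the `BoundArguments` object that `bind_partial` returns. This is only an illustrative sketch: the implementation in this article keeps its own list and dict instead, and `my_sum` is redefined here without the decorator so the snippet runs on its own.

```python
import inspect


def my_sum(a, b, c):
    return a + b + c


sig = inspect.signature(my_sum)

# First call: one positional argument and one keyword argument
first = sig.bind_partial(1, c=3)
print(first.args)    # (1,)
print(first.kwargs)  # {'c': 3}

# A later positional argument is appended after the stored positional ones,
# so it lands on the next free parameter, b
second = sig.bind_partial(*first.args, 2, **first.kwargs)
print(dict(second.arguments))  # {'a': 1, 'b': 2, 'c': 3}
print(my_sum(*second.args, **second.kwargs))  # 6
```

Keeping the positional arguments in order is what lets each new call fill the next unfilled parameter.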
Another important detail is that, in order to be able to use the `@partial`-decorated function multiple times, we need to copy the function before monkey patching attributes onto it, so we don't disturb the original function.
In the end, our code ends up being much more concise and easier to understand:

    import inspect
    from functools import wraps


    def partial(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            sig = inspect.signature(f)
            total_parameters = len(sig.parameters)

            # Initialize stored args if not present
            if not hasattr(f, "_stored_args"):
                f._stored_args = []
            if not hasattr(f, "_stored_kwargs"):
                f._stored_kwargs = {}

            # Combine stored and new arguments
            all_args = f._stored_args + list(args)
            all_kwargs = {**f._stored_kwargs, **kwargs}

            bind = sig.bind_partial(*all_args, **all_kwargs)
            # Apply the default values from the signature to fill out
            # any missing values
            bind.apply_defaults()

            do_currying = len(bind.arguments) < total_parameters
            if do_currying:
                # Copy f (copy_func is described below) so the stored
                # arguments are patched onto the copy, not the original
                fn = copy_func(f)
                fn._stored_args = all_args
                fn._stored_kwargs = all_kwargs
                return partial(fn)

            return f(*all_args, **all_kwargs)

        return wrapper
The `copy_func` function is borrowed from Glenn Maynard, who posted it on StackOverflow. Its primary purpose is to make sure that we are not monkey patching the parameters onto the original function.
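For reference, a helper along these lines can be built from `types.FunctionType` and `functools.update_wrapper`. The following is a sketch of that general approach and is not necessarily identical to the code in the StackOverflow answer. The `my_sum` function and the `_stored_args` value are only there to show that the copy can be patched independently.

```python
import functools
import types


def copy_func(f):
    """Return a new function object with the same code, defaults and closure as f."""
    g = types.FunctionType(
        f.__code__,
        f.__globals__,
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    # Copy metadata such as __doc__, __module__ and the attribute dict
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


# Hypothetical function, used only for illustration
def my_sum(a, b, c):
    return a + b + c


clone = copy_func(my_sum)
clone._stored_args = [1]  # patch the copy only

print(clone(1, 2, 3))                   # 6
print(clone is my_sum)                  # False
print(hasattr(my_sum, "_stored_args"))  # False
```

Attributes set on the copy after it is created do not show up on the original, which is what allows the curried function to be reused.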