Better currying in Python

March 2025

A while ago, in 2022, I wrote an article about why currying can be useful, and how to implement it in Python. The parts about why currying is useful still apply. However, the code was very complicated, and used a lot of hacks to keep track of variables. As a result there were a lot of bugs.

A severe bug was that the decorated function could only be used once. Because we changed the bound parameters directly on the function, after the first call it stayed bound to those specific values, with no way of changing them.

This is not very useful.

Meanwhile, I have also learned a lot more about function introspection, so I decided to reimplement the function from scratch, using function signatures and their bind methods to manage the arguments.

The idea is to use inspect.signature to get the function signature - i.e. to get which args and kwargs the function expects, and then use signature.bind_partial to match the parameters. bind_partial is much nicer to work with than the Parameter inspection we did before: it lets you match values to parameters by passing them just as you would when calling the function. The result can then be amended with bind.apply_defaults(), which fills in any missing values with the default values from the signature.
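To make the two calls concrete, here is a minimal sketch. The greet function and its parameters are made up for illustration; only inspect.signature, bind_partial and apply_defaults come from the standard library.

```python
import inspect


# Hypothetical example function, not part of the original article
def greet(name, greeting="Hello", punctuation="!"):
    return f"{greeting}, {name}{punctuation}"


sig = inspect.signature(greet)

# bind_partial accepts an incomplete set of arguments without raising
bound = sig.bind_partial("Ada", punctuation="?")
before = dict(bound.arguments)  # only the values we actually passed

# apply_defaults fills in parameters that have a default and were not passed
bound.apply_defaults()
after = dict(bound.arguments)

print(before)  # {'name': 'Ada', 'punctuation': '?'}
print(after)   # {'name': 'Ada', 'greeting': 'Hello', 'punctuation': '?'}
```

Note that a parameter with no default and no supplied value simply stays out of bound.arguments.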

Now, we can compare the number of arguments in bind to the total number of parameters the function expects. Parameters that have neither a supplied value nor a default are absent from bind.arguments, so this gives us a robust way of determining whether we can compute the final result or need to curry further.
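The check can be isolated into a small helper. This is a sketch under the assumption that we only care about the count; is_complete and the c=10 default are hypothetical names and values chosen for the example.

```python
import inspect


def is_complete(func, *args, **kwargs):
    """Return True if args/kwargs plus defaults cover every parameter of func."""
    sig = inspect.signature(func)
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()
    return len(bound.arguments) == len(sig.parameters)


# Hypothetical function with one default, to show that defaults count as filled
def my_sum(a, b, c=10):
    return a + b + c


print(is_complete(my_sum, 1))       # False: b is still missing
print(is_complete(my_sum, 1, 2))    # True: c falls back to its default
print(is_complete(my_sum, 1, b=2))  # True: keywords count as well
```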

If we do need to curry, we have to keep the positional and keyword arguments separate, so that positional arguments passed in later calls are appended after the ones already stored. For example:

@partial
def my_sum(a, b, c):
  return a + b + c

my_sum(1)(2)(3) # should output 6
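To illustrate why the two kinds of arguments are stored separately, here is a small sketch of what happens between two calls. The stored_args and stored_kwargs variables stand in for whatever state the decorator keeps, and the call sequence my_sum(1, c=3)(2) is an assumed example.

```python
import inspect


def my_sum(a, b, c):
    return a + b + c


sig = inspect.signature(my_sum)

# State after a hypothetical first call my_sum(1, c=3)
stored_args = [1]
stored_kwargs = {"c": 3}

# A second call supplies one more positional argument: (2)
new_args, new_kwargs = (2,), {}

# New positionals are appended after the stored ones; keywords are merged
all_args = stored_args + list(new_args)
all_kwargs = {**stored_kwargs, **new_kwargs}

bound = sig.bind_partial(*all_args, **all_kwargs)
print(dict(bound.arguments))  # {'a': 1, 'b': 2, 'c': 3}

result = my_sum(*all_args, **all_kwargs)
print(result)  # 6
```

One consequence of this design is that a value first given by keyword cannot later be given positionally as well; binding would raise a TypeError for the duplicated parameter.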

Another important detail is that in order to be able to use the @partial decorated function multiple times, we need to copy the function before monkey patching attributes to it, so we don't disturb the original function.

In the end, our code ends up being much more concise, and easier to understand.

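import inspect
from functools import wraps

# copy_func is the function-copying helper described after this listing

def partial(f):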
  @wraps(f)
  def wrapper(*args, **kwargs):
    sig = inspect.signature(f)
    total_parameters = len(sig.parameters)

    # Initialize stored args if not present
    if not hasattr(f, "_stored_args"):
      f._stored_args = []
    if not hasattr(f, "_stored_kwargs"):
      f._stored_kwargs = {}

    # Combine stored and new arguments
    all_args = f._stored_args + list(args)
    all_kwargs = {**f._stored_kwargs, **kwargs}

    bind = sig.bind_partial(*all_args, **all_kwargs)
    # apply the default values from the signature to fill out 
    # any missing values
    bind.apply_defaults()

    do_currying = len(bind.arguments) < total_parameters
    if do_currying:
      fn = copy_func(f)
      fn._stored_args = all_args
      fn._stored_kwargs = all_kwargs
      return partial(fn)

    return f(*all_args, **all_kwargs)

  return wrapper

The copy_func function is borrowed from Glenn Maynard, who posted it on StackOverflow. Its primary purpose is to make sure that we are not monkey patching the stored arguments onto the original function.
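For readers who want something to experiment with, the following is a sketch of what such a helper could look like. It is not necessarily identical to the code in the StackOverflow answer; it is one common way to copy a function using types.FunctionType from the standard library, and the add function is a made-up example.

```python
import types


def copy_func(f):
    """Return a new function object with the same code, defaults and attributes as f."""
    g = types.FunctionType(
        f.__code__,
        f.__globals__,
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    g.__kwdefaults__ = f.__kwdefaults__
    g.__doc__ = f.__doc__
    g.__qualname__ = f.__qualname__
    g.__module__ = f.__module__
    # Copy custom attributes into the new function's own __dict__
    g.__dict__.update(f.__dict__)
    return g


# Hypothetical function used only to demonstrate the copy
def add(a, b=2):
    return a + b


add._stored_args = [1]

clone = copy_func(add)
clone._stored_args = [1, 2]  # rebinding on the copy leaves the original alone

print(add._stored_args)    # [1]
print(clone._stored_args)  # [1, 2]
print(clone(1))            # 3
```

The copy is shallow: attribute values such as lists are shared until they are reassigned, which is fine here because the decorator assigns fresh objects to the copy rather than mutating them in place.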
