Currying: Partial Functions in Python

One of the cool things about functional programming languages such as Haskell is the concept of partially applying a function, also known as currying.

It serves as a way of encapsulating information, hiding details from the caller. This is a powerful way of dealing with abstraction, which we will come back to later.

Note: I will use currying, binding, and partial application interchangeably in this article. There are some differences, but those are outside the scope of this article.

Mathematically, if we have a function $f(a,b,c)$, we can bind $a$ by calling $f(a)$, which gives us a new function of the remaining parameters, $g(b,c) = f(a,b,c)$. Many functional programming languages provide a neat syntax for this that involves just calling the function. For example, in Haskell you can write

f a

which returns a new function that takes the two remaining parameters $b$ and $c$.
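To make the idea concrete in Python terms, here is a hand-written curried version of $f$, where each call binds one argument and returns a function waiting for the next. This is only a sketch of the concept; ordinary Python functions do not behave this way on their own.

```python
def f(a):
  # Each call binds one argument and returns a function awaiting the next.
  def take_b(b):
    def take_c(c):
      return a + b + c
    return take_c
  return take_b

g = f(5)        # a is bound; g expects b, then c
print(g(4)(3))  # 12
```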

However, if we try to do the same in Python, we get an error telling us we didn't specify the remaining arguments.

>>> f = lambda a,b,c: a+b+c
>>> f(5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: f() missing 2 required positional arguments: 'b' and 'c'

There are a couple of different ways we can curry functions in Python. The most obvious is probably to define a new function using a lambda:

lambda b,c: f(5,b,c)

However, while this works, it can be quite verbose. There must be a better way, and indeed there is: we can use partial from the functools module.

from functools import partial
partial(f,5)
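The object returned by functools.partial is callable, and it also records what has been bound so far through its documented func, args and keywords attributes, which can be handy when debugging. A small sketch with the same f as above:

```python
from functools import partial

f = lambda a, b, c: a + b + c
g = partial(f, 5)   # bind a=5

print(g(4, 3))      # 12
print(g.func is f)  # True: the wrapped callable
print(g.args)       # (5,): positional arguments bound so far
print(g.keywords)   # {}: keyword arguments bound so far
```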

We can even bind multiple arguments at the same time, and use keyword arguments.

f = lambda a,b,c: a+b+c
partial(f,5,4)(3) # 12
partial(f,5,c=4)(3) # 12
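One thing worth knowing about keyword binding with functools.partial: positional arguments passed later still fill parameters from the left, so binding a middle parameter such as b by keyword only helps if the later arguments avoid its slot. A short sketch of what happens:

```python
from functools import partial

f = lambda a, b, c: a + b + c
g = partial(f, b=5)  # bind b by keyword

print(g(1, c=3))     # 9: a=1, b=5, c=3

try:
  g(1, 3)            # 3 goes to b positionally, but b is already bound
except TypeError as err:
  print(err)         # reports multiple values for argument 'b'
```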

This is much better, but it still bugs me that the syntax isn't as elegant as Haskell's just calling the function.

Why can't we have $f(b=5)$?

It turns out we can, if we accept decorating f slightly. That isn't a big deal, as we usually want to use it on our own functions, and wrapping other functions is no harder than using functools.partial.

For starters, let us simplify the problem a bit. Haskell doesn't support keyword arguments: you have to bind arguments in order, starting with the first one. If we impose the same constraint on our solution, we can do something like this:

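def partial(f, nargs=None):
  """Curry f by position. Does not support kwargs."""
  def wrapper(*args):
    # On the first call nargs is None, so read the arity from f itself.
    total_parameters = wrapper.nargs or f.__code__.co_argcount
    given_parameters = len(args)
    do_currying = total_parameters > given_parameters

    if do_currying:
      # Bind the arguments given so far and wait for the rest.
      def c(*newargs):
        return f(*(args + newargs))

      # c takes *newargs, so its co_argcount is 0; pass the remaining count explicitly.
      return partial(c, nargs=total_parameters - given_parameters)

    return f(*args)

  wrapper.nargs = nargs
  return wrapper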

The decorator works by comparing the number of arguments given with the total number of arguments the function expects. If too few were given, it constructs a new function that binds the arguments given so far and expects the remaining ones, and wraps that in partial again. This repeats recursively until enough arguments have been supplied to call the original function. This works because Python is a dynamic language, and we can inspect a function's parameters at runtime.
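The runtime information mentioned above lives on the function's code object. A small sketch of what it reports: co_argcount counts the positional parameters (not keyword-only ones, *args or **kwargs), while co_varnames lists the parameter names followed by any local variables, so only its first co_argcount entries are parameters.

```python
def ssum(a, b, c):
  total = 100*a + 10*b + c  # a local variable, not a parameter
  return total

code = ssum.__code__
print(code.co_argcount)                     # 3
print(code.co_varnames)                     # ('a', 'b', 'c', 'total')
print(code.co_varnames[:code.co_argcount])  # ('a', 'b', 'c')
```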

We can now use the partial decorator on a function to enable currying for that function.

@partial
def multiply(x,y):
  return x*y

multiply(1,2)  # 2
multiply(1)(2) # 2

@partial
def ssum(a,b,c):
  return 100*a+10*b+c

ssum(1,2,3)   # 123
ssum(1)(2)(3) # 123
ssum(1,2)(3)  # 123

This is already quite neat. But it suffers from the same limitation as Haskell: we cannot call f(b=3) on f(a,b,c) and get an f(a,c) function out. The Haskell community has a few convoluted answers for how to get around this limitation, but generally recommends avoiding this pattern.

But I think it's sad to avoid such a flexible and powerful pattern, so how can we extend the solution to include keywords in Python?

After some trial and error, I came up with this solution.

def partial(f, nargs=None):
  def wrapper(*args, **kwargs):
    total_parameters = wrapper.nargs or f.__code__.co_argcount
    given_parameters = len(args) + len(kwargs)

    do_currying = total_parameters > given_parameters

    # Start from the keyword arguments; positional ones are converted below.
    combined_kwargs = dict(kwargs)

    # Remove the keyword arguments from the list of names
    # and let the positional args fill in the 'gaps'.
    remaining_var_names = [var for var in f.varnames if var not in kwargs]

    # Fill in the gaps using the args, converting them to keyword arguments,
    # and keep a list of unfilled names for later currying.
    unfilled_args = list(remaining_var_names)
    for i, value in enumerate(args):
      key = remaining_var_names[i]
      combined_kwargs[key] = value
      unfilled_args.remove(key)

    if do_currying:
      def c(**newkwargs):
        return f(**{**combined_kwargs, **newkwargs})
      c.varnames = unfilled_args
      # c only takes **newkwargs, so pass the number of remaining arguments explicitly.
      return partial(c, nargs=total_parameters - len(combined_kwargs))

    return f(**combined_kwargs)

  if not hasattr(f, "varnames"):
    # co_varnames also lists local variables; keep only the parameter names.
    f.varnames = f.__code__.co_varnames[:f.__code__.co_argcount]

  wrapper.nargs = nargs
  return wrapper

It's not the prettiest solution, but it works. It works in much the same way as the previous solution, recursively building more and more specific functions, but this time we convert all the arguments to keyword arguments before passing them to the function, and we need some extra plumbing to keep track of which argument names still need to be filled, not just how many.
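The bookkeeping described above can be isolated into a small helper. The name fill_gaps is hypothetical and not part of the decorator; it just mirrors the decorator's logic of letting keyword arguments claim their names first and letting positional arguments fill the remaining names in order:

```python
def fill_gaps(varnames, args, kwargs):
  """Hypothetical helper: convert positional args to keyword args by
  filling the parameter names not already given as keywords."""
  remaining = [name for name in varnames if name not in kwargs]
  if len(args) > len(remaining):
    raise TypeError("too many positional arguments")
  combined = dict(kwargs)
  for name, value in zip(remaining, args):
    combined[name] = value
  unfilled = remaining[len(args):]  # names still waiting for a value
  return combined, unfilled

print(fill_gaps(("a", "b", "c"), (1,), {"c": 3}))
# ({'c': 3, 'a': 1}, ['b'])
```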

Going back to our test functions, this lets us partially bind like this:

@partial
def multiply(x,y):
  return x*y

multiply(1,2)       # 2
multiply(1)(2)      # 2
multiply(y=1)(2)    # 2
multiply(1)(y=2)    # 2
multiply(x=1)(y=2)  # 2
multiply(x=1)(2)    # 2

@partial
def ssum(a,b,c):
  return 100*a+10*b+c

ssum(1,2,3)         # 123
ssum(1)(2)(3)       # 123
ssum(1,2)(3)        # 123
ssum(1)(c=3)(2)     # 123
ssum(c=3)(b=2)(1)   # 123
ssum(1)(2)(c=3)     # 123

which is pretty neat.

Why is currying useful?

Sure, you might say, it's neat that we can call functions and get a partially applied function back, but how is this useful?

As I alluded to in the beginning, partial application can be a powerful tool for handling abstractions. For example, let us consider dependency injection.

Suppose we have a web application. We would probably have a function such as get_user(id), but there are different ways we might get the user: from a database, from a cache, or from mock data in memory. We would typically implement this in Python by defining an abstract base class or interface, which we inherit from or implement for each storage method.
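For contrast, a minimal sketch of the class-based approach. The names UserRepository and InMemoryUserRepository are hypothetical, and only the in-memory variant is shown:

```python
from abc import ABC, abstractmethod

class UserRepository(ABC):
  """Hypothetical interface: one implementation per storage method."""

  @abstractmethod
  def get_user(self, user_id):
    """Return the user stored under user_id."""

class InMemoryUserRepository(UserRepository):
  def __init__(self, users):
    self.users = users  # mock data: a dict keyed by user id

  def get_user(self, user_id):
    return self.users[user_id]

# The rest of the application only depends on the UserRepository interface.
repo = InMemoryUserRepository({1: "alice"})
print(repo.get_user(1))  # alice
```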

With partial application, we have another option: we define a function for each storage method, and then bind the parameters that method requires, creating a storage-agnostic get_user(id) function for the rest of the application to use.

@partial
def get_user_from_db(dbconnection, id):
  ...

@partial
def get_user_from_memory(users, id):
  ...

@partial
def get_user_from_cache(cache, id):
  ...

# specify a dependency-agnostic function to use in the rest of the application.
get_user = get_user_from_db(conn)
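Since conn and the function bodies above are placeholders, here is a self-contained sketch of the same idea using the in-memory variant. It uses the standard functools.partial rather than the decorator so that it runs on its own; the mock data is made up:

```python
from functools import partial

def get_user_from_memory(users, user_id):
  # users: a dict of mock data keyed by user id
  return users[user_id]

mock_users = {1: "alice", 2: "bob"}

# Bind the storage-specific dependency once...
get_user = partial(get_user_from_memory, mock_users)

# ...and the rest of the application only sees get_user(id).
print(get_user(2))  # bob
```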

Generally, partial application is analogous to inheriting from a general class to create a more specialized one.

It can also be used when a function is passed between different parts of a system, with each part configuring some of its parameters without the parts knowing about each other.
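A sketch of that scenario, using functools.partial so it is self-contained. The send function, host and port are hypothetical stand-ins; note that message has to be passed by keyword, because host and port were bound by keyword and a positional argument would land on host:

```python
from functools import partial

def send(host, port, message):
  # Hypothetical stand-in for a real network call.
  return f"{host}:{port} <- {message}"

# One part of the system knows the host...
send_to_host = partial(send, host="example.com")
# ...another part, unaware of the first, fixes the port.
send_configured = partial(send_to_host, port=8080)

print(send_configured(message="hello"))  # example.com:8080 <- hello
```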

Moreover, it can be used as a way of achieving pseudo* lazy evaluation, where the function body only runs once you have gathered all the parameters. *Pseudo, because the arguments themselves are still evaluated eagerly, as the Python interpreter always does.
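A small sketch of the "pseudo" part using functools.partial: the argument expressions run as soon as they are passed, while the function body waits until the last argument arrives. The noisy helper is hypothetical and only records when it is evaluated:

```python
from functools import partial

log = []

def noisy(value):
  # Records when an argument expression is evaluated.
  log.append(f"evaluated {value}")
  return value

def body(a, b):
  log.append("body ran")
  return a + b

p = partial(body, noisy(1))  # the argument is evaluated right away...
print(log)                   # ['evaluated 1']
result = p(noisy(2))         # ...but the body only runs once b arrives
print(log)                   # ['evaluated 1', 'evaluated 2', 'body ran']
print(result)                # 3
```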

Even if you don't adopt functional programming patterns or need these situational benefits, partial application can still give you extra flexibility when manipulating functions, with very little overhead.

And by using the decorator style, the syntax becomes elegant and succinct enough that it is easy to read, and not cumbersome to use.

Update

I have updated the implementation to be a bit more robust, using the built-in functools and inspect modules for function introspection instead of finding the arguments manually.

from functools import wraps
import inspect
from typing import Any, Callable, Dict, Optional

def partial(f: Callable, _bound: Optional[Dict[str, Any]] = None) -> Callable:
  """A wrapper for a user-defined function that allows for currying.

  Example:

  @partial
  def foo(a,b,c):
    return a+b+c

  > foo(b=1)(1)(1) == 3
  > True

  Args:
      f (Callable): function to curry
      _bound (dict): arguments bound so far (used internally when currying)

  Returns:
      Callable: A curried version of f. It keeps returning new curried
      functions until every required parameter is bound, then calls f.
  """
  P = inspect.Parameter
  params = inspect.signature(f).parameters

  # All arguments are converted to keyword arguments, so *args can't be supported.
  if any(p.kind == P.VAR_POSITIONAL for p in params.values()):
    raise ValueError("We currently don't support *args in function as we convert all arguments to keyword arguments.")

  # Detect a **kwargs parameter by its kind, whatever it is called.
  has_var_kwargs = any(p.kind == P.VAR_KEYWORD for p in params.values())
  named = [name for name, p in params.items() if p.kind != P.VAR_KEYWORD]
  bound = dict(_bound or {})

  @wraps(f)
  def wrapper(*args, **kwargs):
    # Work on a copy so that reusing a partially applied function
    # (or calling the decorated function again) never sees stale arguments.
    new_bound = dict(bound)

    # update keyword arguments; unknown names need a **kwargs parameter
    for key, value in kwargs.items():
      if key not in params and not has_var_kwargs:
        raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{key}'")
      new_bound[key] = value

    # update positional arguments: they fill the unbound parameters in order
    gaps = [name for name in named if name not in new_bound]
    if len(args) > len(gaps):
      raise TypeError(f"{f.__name__}() got {len(args)} positional arguments but only {len(gaps)} parameters are unbound")
    new_bound.update(zip(gaps, args))

    # keep currying while a parameter without a default is still unbound
    if any(name not in new_bound and params[name].default is P.empty for name in named):
      return partial(f, new_bound)

    # positional-only parameters can't be passed by keyword, so pass them in order
    pos_only = [new_bound.pop(name, params[name].default)
                for name in named if params[name].kind == P.POSITIONAL_ONLY]
    return f(*pos_only, **new_bound)

  return wrapper
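For reference, this is the kind of information inspect.signature exposes and that the update builds on: each parameter's kind and whether it has a default. A short sketch with a made-up example function:

```python
import inspect

def example(a, b=2, *, c, **options):
  return a, b, c, options

for name, param in inspect.signature(example).parameters.items():
  # kind says how the parameter may be passed; a missing default is Parameter.empty
  has_default = param.default is not inspect.Parameter.empty
  print(name, param.kind.name, has_default)
# a POSITIONAL_OR_KEYWORD False
# b POSITIONAL_OR_KEYWORD True
# c KEYWORD_ONLY False
# options VAR_KEYWORD False
```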
