Caching and Decorators
Instrumentation
Consider the lowly Fibonacci function:
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib(10)
55
This function is a straightforward translation of the mathematical definition: $F_1 = F_2 = 1$, and $F_n = F_{n-2} + F_{n-1}$ for $n > 2$.
fib is a recursive function. A single external call to fib can result in many recursive calls, where it calls itself.
Let’s instrument fib to report how many calls result from a single external call:
count = 0
def fib(n):
    global count
    count += 1
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 109 times
The number of calls increases (exponentially) as a function of the argument:
count = 0
def fib(n):
    global count
    count += 1
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
for i in [1, 2, 3, 4, 5, 10, 20, 30]:
    count = 0
    print("fib({}) = {}; fib was called {:,} times".format(i, fib(i), count))
fib(1) = 1; fib was called 1 times
fib(2) = 1; fib was called 1 times
fib(3) = 2; fib was called 3 times
fib(4) = 3; fib was called 5 times
fib(5) = 5; fib was called 9 times
fib(10) = 55; fib was called 109 times
fib(20) = 6765; fib was called 13,529 times
fib(30) = 832040; fib was called 1,664,079 times
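The counts in this table follow a pattern worth checking. Every call either hits a base case or makes two further calls, so the number of calls $C(n)$ satisfies $C(1) = C(2) = 1$ and $C(n) = 1 + C(n-2) + C(n-1)$, which works out to $C(n) = 2 F_n - 1$. The sketch below verifies this against the numbers printed above; the helper names fib_iter and call_count are chosen for this illustration and are not part of the original notes.

```python
def fib_iter(n):
    """Return the n-th Fibonacci number (1, 1, 2, 3, ...) without recursion."""
    a, b = 1, 1
    for _ in range(n - 2):
        a, b = b, a + b
    return b

def call_count(n):
    """Number of calls the naive recursive fib makes for argument n."""
    prev2, prev1 = 1, 1          # call counts for arguments 1 and 2
    for _ in range(n - 2):
        # One call for this argument, plus the calls made by its two sub-calls.
        prev2, prev1 = prev1, 1 + prev2 + prev1
    return prev1

for n in [1, 2, 3, 4, 5, 10, 20, 30]:
    assert call_count(n) == 2 * fib_iter(n) - 1
    print("fib({}) needs {:,} calls".format(n, call_count(n)))
```

Since Fibonacci numbers themselves grow exponentially, so does the number of calls.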
The fact that it calls itself so many times has a direct effect on performance:
%timeit fib(10)
%timeit fib(20)
%timeit fib(30)
13.2 µs ± 684 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.74 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
192 ms ± 7.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Note the microseconds (µs) and milliseconds (ms) per loop. 200 ms is 1/5 of a second, so it’s getting pretty slow.
Let’s instrument the function to print each time it’s called. This way we can see more about what’s going on.
def fib(n):
    print("Calling fib({})".format(n))
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib(6)
Calling fib(6)
Calling fib(4)
Calling fib(2)
Calling fib(3)
Calling fib(1)
Calling fib(2)
Calling fib(5)
Calling fib(3)
Calling fib(1)
Calling fib(2)
Calling fib(4)
Calling fib(2)
Calling fib(3)
Calling fib(1)
Calling fib(2)
8
This output shows that fib is called multiple times with the same arguments. For example, fib(4) is called twice, fib(3) is called three times, and fib(2) is called five times.
The modified instrumentation below reports how many times fib is applied to each argument value. (This is also a chance to learn about defaultdict.)
from collections import defaultdict
counts = defaultdict(lambda: 0) # Default values will be 0
def fib(n):
    counts[n] += 1
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib(6)
for n, count in sorted(counts.items()):
    print("fib({}) was called {} times".format(n, count))
fib(1) was called 3 times
fib(2) was called 5 times
fib(3) was called 3 times
fib(4) was called 2 times
fib(5) was called 1 times
fib(6) was called 1 times
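For readers new to defaultdict: it behaves like a dict, except that looking up a missing key calls the factory function and stores the result instead of raising KeyError. The small sketch below uses a made-up word list; defaultdict(int) has the same effect as the defaultdict(lambda: 0) used above, because int() returns 0.

```python
from collections import defaultdict

words = ["spam", "egg", "spam", "ham", "spam"]  # illustrative data

# With a plain dict, each key has to be initialised before it is incremented.
plain = {}
for w in words:
    if w not in plain:
        plain[w] = 0
    plain[w] += 1

# With defaultdict, a missing key is created on first access using int() == 0.
tally = defaultdict(int)
for w in words:
    tally[w] += 1

print(dict(tally) == plain)   # True: both hold the same counts
print(tally["missing"])       # 0, and the key "missing" now exists in tally
```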
A Digression
You might notice an interesting pattern in the sequence of call counts above. Reading from the bottom (fib(6)) and skipping fib(1), the number of times fib is called with each argument value from $6$ down to $2$ is $1, 1, 2, 3, 5$.
Let’s see if this holds up:
counts = defaultdict(lambda: 0)
fib(12)
for n, count in sorted(counts.items()):
    print("fib({}) was called {} times".format(n, count))
fib(1) was called 55 times
fib(2) was called 89 times
fib(3) was called 55 times
fib(4) was called 34 times
fib(5) was called 21 times
fib(6) was called 13 times
fib(7) was called 8 times
fib(8) was called 5 times
fib(9) was called 3 times
fib(10) was called 2 times
fib(11) was called 1 times
fib(12) was called 1 times
Note the sequence: the number of times fib is called with arguments $12, 11, 10, \ldots, 4, 3, 2$ is $1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89$. What does this sequence remind you of? Harder: why does it happen?
Caching
Instead of instrumenting the code simply to record how many times the function was called, we can modify it to cache each computed value and reuse it.
You can use the same technique to cache web requests in the Text Mining project. There it avoids repeated network requests; here it avoids repeated computation. This reduces the function’s computational complexity: the number of computations now grows linearly with $n$ rather than exponentially.
fib_cache = {}
count = 0
def fib(n):
    global count
    if n in fib_cache:
        return fib_cache[n]
    count += 1
    if n <= 2:
        return 1
    else:
        result = fib(n-2) + fib(n-1)
        fib_cache[n] = result
        return result
print("fib({}) = {}; fib was computed {} times".format(10, fib(10), count))
fib(10) = 55; fib was computed 11 times
Compare this to the uncached fib, where fib(10) resulted in 109 calls to fib.
Savings are (exponentially) greater for greater values of $n$.
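The same dictionary-based pattern carries over to the web-request case mentioned earlier. The sketch below is only an illustration: fetch_page is a stand-in that fabricates a string instead of touching the network, and the names fetch_page, cached_fetch, and page_cache are hypothetical rather than taken from the Text Mining project.

```python
page_cache = {}   # maps URL -> page text
fetch_count = 0   # how many times the "slow" fetch actually ran

def fetch_page(url):
    """Stand-in for a slow network request (no real network access)."""
    global fetch_count
    fetch_count += 1
    return "<html>contents of {}</html>".format(url)

def cached_fetch(url):
    """Return the page for url, fetching it only the first time it is requested."""
    if url not in page_cache:
        page_cache[url] = fetch_page(url)
    return page_cache[url]

for url in ["http://example.com/a", "http://example.com/b", "http://example.com/a"]:
    cached_fetch(url)

print("3 requests, {} actual fetches".format(fetch_count))
```

The third request is served from page_cache, so only two fetches run.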
Let’s re-run the timings. We’ll time a wrapper for the cached fib that resets fib’s cache each time.
def rfib(n):
    global fib_cache
    fib_cache = {}
    return fib(n)
%timeit rfib(10)
%timeit rfib(20)
%timeit rfib(30)
2.64 µs ± 43.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
6.49 µs ± 551 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.7 µs ± 346 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
For reference, the un-cached timings looked like this:
13.2 µs ± 684 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.74 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
192 ms ± 7.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
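%timeit is an IPython/Jupyter magic command, so it isn’t available in a plain Python script. A rough equivalent using the standard-library timeit module is sketched below. The names naive_fib and cached_fib are chosen for this sketch, the cache is rebuilt on every top-level call to mirror rfib, and the absolute numbers will differ from machine to machine.

```python
import timeit

def naive_fib(n):
    if n <= 2:
        return 1
    return naive_fib(n - 2) + naive_fib(n - 1)

def cached_fib(n, cache=None):
    """Cached Fibonacci; a fresh cache is created for each top-level call."""
    if cache is None:
        cache = {}
    if n in cache:
        return cache[n]
    result = 1 if n <= 2 else cached_fib(n - 2, cache) + cached_fib(n - 1, cache)
    cache[n] = result
    return result

for n in [10, 20]:
    # number=100 runs each callable 100 times; divide to get seconds per call.
    t_naive = timeit.timeit(lambda: naive_fib(n), number=100) / 100
    t_cached = timeit.timeit(lambda: cached_fib(n), number=100) / 100
    print("n={}: naive {:.2e} s, cached {:.2e} s per call".format(n, t_naive, t_cached))
```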
Separating concerns
The code above takes a step forwards in terms of performance, but backwards in legibility. Half of it is concerned with doing the math, and half of it is concerned with caching. It violates the principle of separation of concerns.
Let’s separate the first instrumented fib, the one that counts how often it’s been called, into two functions. The inner function, fib_, does the computation. The outer function, fib, wraps fib_ and adds the instrumentation.
(The underscore in fib_ is used in Python the same way a prime $'$ is used in math. fib_ corresponds to $\textrm{fib}'$ or $F'$.)
count = 0
def fib_(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
def fib(n):
    global count
    count += 1
    return fib_(n)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 109 times
We can change the outer function to add different instrumentation. We don’t need to touch the inner function.
from collections import defaultdict
counts = defaultdict(lambda: 0)
def fib(n):
    counts[n] += 1
    return fib_(n)
fib(10)
for n, count in sorted(counts.items()):
    print("fib({}) was called {} times".format(n, count))
fib(1) was called 21 times
fib(2) was called 34 times
fib(3) was called 21 times
fib(4) was called 13 times
fib(5) was called 8 times
fib(6) was called 5 times
fib(7) was called 3 times
fib(8) was called 2 times
fib(9) was called 1 times
fib(10) was called 1 times
Now let’s define a different outer function, that adds caching. (We’ll also keep some instrumentation, so we can see that the cache is working.) It can use the same inner function.
count = 0
fib_cache = {}
def fib(n):
    global count
    if n in fib_cache:
        return fib_cache[n]
    count += 1
    result = fib_(n)
    fib_cache[n] = result
    return result
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 10 times
Higher-Order Programming
The various outer functions above – all named fib – didn’t need to know anything about fib_ specifically. The name fib_ was included in their definitions, but they otherwise could have wrapped any other (unary) function instead.
Let’s extract fib_ from the function definition, and instead supply it as a parameter:
count = 0
def fib_(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
fib = counting(fib_)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 109 times
Above, def fib_ defined fib_, counting(fib_) used that value, and fib = … (on the same line) defined fib.
We don’t need the value of fib_ after we’ve used it as an argument to counting. We can therefore def fib instead of def fib_:
count = 0
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
fib = counting(fib)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 109 times
This relieves us from having to come up with separate names for fib_ and fib.
It also eliminates the inelegance where fib_ had to know to call fib (which hadn’t been defined yet).
Previously, in order to recognize that fib_ was actually recursive, we needed to know that fib_ was going to be wrapped and that the wrapped function would be called fib.
Now, you can read the functionality of fib straight from its definition. As a bonus, if we comment out the fib = counting(fib) line, fib still works – it just isn’t instrumented.
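One limitation of counting as written is that every wrapped function shares the single global count. A possible variation, sketched here under the assumption that a per-function counter is wanted, stores the count as an attribute on the wrapper itself. The attribute name calls is an arbitrary choice for this illustration.

```python
def counting(fn):
    def wrapper(n):
        wrapper.calls += 1      # the counter lives on the wrapper function object
        return fn(n)
    wrapper.calls = 0
    return wrapper

def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)

fib = counting(fib)

print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), fib.calls))
```

Because the recursive calls inside fib look up the global name fib, they go through the wrapper too, so the count matches the global-variable version.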
We can wrap a function multiple times. fib below has been wrapped by a function (counting) that adds counting instrumentation, and then wrapped in a function (cached) that adds caching. The final value of fib is a function that both records how many times it’s been called and is cached.
count = 0
def cached(fn):
    cache = {}
    def wrapper(n):
        if n in cache:
            return cache[n]
        result = fn(n)
        cache[n] = result
        return result
    return wrapper
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib = counting(fib)
fib = cached(fib)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 10 times
Let’s write one more wrapping function, for even more instrumentation:
count = 0
def cached(fn):
    cache = {}
    def wrapper(n):
        if n in cache:
            return cache[n]
        result = fn(n)
        cache[n] = result
        return result
    return wrapper
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
def traced(fn):
    def wrapper(n):
        print('called {}({})'.format(fn.__name__, n))
        return fn(n)
    return wrapper
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib = traced(fib)
fib = counting(fib)
fib = cached(fib)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
called fib(10)
called fib(8)
called fib(6)
called fib(4)
called fib(2)
called fib(3)
called fib(1)
called fib(5)
called fib(7)
called fib(9)
fib(10) = 55; fib was called 10 times
Order matters! Above, the tracing wrapper is inside the caching wrapper, so only cache misses are traced. Below, the tracing happens outside the cache, so every call is traced, including the ones answered from the cache.
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib = cached(fib)
fib = counting(fib)
fib = traced(fib)
count = 0
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
called wrapper(10)
called wrapper(8)
called wrapper(6)
called wrapper(4)
called wrapper(2)
called wrapper(3)
called wrapper(1)
called wrapper(2)
called wrapper(5)
called wrapper(3)
called wrapper(4)
called wrapper(7)
called wrapper(5)
called wrapper(6)
called wrapper(9)
called wrapper(7)
called wrapper(8)
fib(10) = 55; fib was called 17 times
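The trace above says wrapper rather than fib because traced is now wrapping the function returned by counting, whose __name__ is wrapper. A common remedy, not used in the original notes, is functools.wraps, which copies the wrapped function’s name and docstring onto the wrapper. A minimal sketch:

```python
from functools import wraps

count = 0

def counting(fn):
    @wraps(fn)              # copy fn's __name__ and docstring onto wrapper
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper

def traced(fn):
    @wraps(fn)
    def wrapper(n):
        print('called {}({})'.format(fn.__name__, n))
        return fn(n)
    return wrapper

def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)

fib = counting(fib)
fib = traced(fib)

fib(4)                      # each trace line now reads "called fib(...)"
print(fib.__name__)         # fib
```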
The functions that do the wrapping can be applied to any (unary) function. Here, we’ll apply them to a version of the exponentiation function that’s been modified to take a single argument, a tuple $(\textrm{base}, \textrm{exp})$, so that pow((b, e)) $= b^e$.
(The forthcoming Appendix will show how to modify these wrappers for use on functions that take different numbers of arguments.)
def pow(base_and_exp):
    base, exp = base_and_exp
    if exp == 0:
        return 1
    if exp == 1:
        return base
    half = exp // 2
    return pow((base, half)) * pow((base, exp - half))
pow = traced(pow)
pow = cached(pow)
print('exp({}, {}) = {}'.format(2, 15, pow((2, 15))))
called pow((2, 15))
called pow((2, 7))
called pow((2, 3))
called pow((2, 1))
called pow((2, 2))
called pow((2, 4))
called pow((2, 8))
exp(2, 15) = 32768
Terminology
A function that takes a function as an argument or returns a function as a value (the wrappers above do both) is called a higher-order function, functor, or functional.
Programming with higher-order functions is higher-order programming.
“Functor” and “functional” are also used in math, with meanings that aren’t that different. For example, the (indefinite) integral $\int$ takes a function $x \rightarrow x^2$ as an argument and returns another function $x \rightarrow \frac{1}{3} x^3$ (up to an additive constant) as a value: $\int x^2 \, dx = \frac{1}{3} x^3 + C$.
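To make the mathematical analogy concrete, here is a sketch of a numerical derivative operator: it takes a function and returns a new function, just as the wrappers above do. The name derivative and the step size h are choices made for this illustration, and the result is only an approximation.

```python
def derivative(f, h=1e-5):
    """Return a function approximating f' by a central difference with step h."""
    def df(x):
        return (f(x + h) - f(x - h)) / (2 * h)
    return df

def cube_over_three(x):
    return x**3 / 3

approx_square = derivative(cube_over_three)   # approximately x -> x**2
print(approx_square(2.0))                     # close to 4.0
```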
Decorator Syntax
Writing fib = counting(fib) after fib has been defined is equivalent to writing @counting immediately before the function definition.
This is the Python decorator construct.
count = 0
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
@counting
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 109 times
count = 0
def cached(fn):
    cache = {}
    def wrapper(n):
        if n in cache:
            return cache[n]
        result = fn(n)
        cache[n] = result
        return result
    return wrapper
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
@cached
@counting
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
fib(10) = 55; fib was called 10 times
count = 0
def cached(fn):
    cache = {}
    def wrapper(n):
        if n in cache:
            return cache[n]
        result = fn(n)
        cache[n] = result
        return result
    return wrapper
def counting(fn):
    def wrapper(n):
        global count
        count += 1
        return fn(n)
    return wrapper
def traced(fn):
    def wrapper(n):
        print('called {}({})'.format(fn.__name__, n))
        return fn(n)
    return wrapper
@cached
@counting
@traced
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
called fib(10)
called fib(8)
called fib(6)
called fib(4)
called fib(2)
called fib(3)
called fib(1)
called fib(5)
called fib(7)
called fib(9)
fib(10) = 55; fib was called 10 times
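Stacked decorators are applied bottom-up: the decorator closest to def wraps the function first, and the one at the top wraps last, so the block above is the same as fib = cached(counting(traced(fib))). The sketch below makes the order visible with a made-up decorator factory, tag, that records when each wrapper is entered.

```python
order = []

def tag(label):
    """Return a decorator that records label each time the wrapped function is called."""
    def decorator(fn):
        def wrapper(n):
            order.append(label)
            return fn(n)
        return wrapper
    return decorator

@tag("outer")
@tag("inner")
def double(n):
    return 2 * n

print(double(21))   # 42
print(order)        # ['outer', 'inner']
```

The outermost wrapper runs first on each call, which is why a cache placed at the top of the stack can short-circuit everything beneath it.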
Caching: the built-in option
Since Python comes with “batteries included”, it has a standard-library module called functools dedicated to higher-order functions.
This module includes a decorator called lru_cache that behaves similarly to the cached decorator we wrote. The name comes from the fact that when the cache runs out of space, it is the Least Recently Used entries that are evicted.
from functools import lru_cache
@lru_cache(maxsize=None)
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
print("fib({}) = {}".format(10, fib(10)))
print(fib.cache_info())
fib(10) = 55
CacheInfo(hits=7, misses=10, maxsize=None, currsize=10)
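maxsize=None makes the cache unbounded, so nothing is ever evicted. To see the least-recently-used policy in action, the sketch below uses a deliberately tiny cache; the function square and the call sequence are made up for illustration.

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def square(n):
    print("computing square({})".format(n))
    return n * n

square(1)   # miss: computed and stored
square(2)   # miss: computed and stored
square(1)   # hit: 1 becomes the most recently used entry
square(3)   # miss: the cache is full, so the least recently used entry (2) is evicted
square(2)   # miss again: 2 has to be recomputed, which evicts 1
info = square.cache_info()
print(info)

square.cache_clear()   # empties the cache and resets the statistics
```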
Additional Reading
Glossary:
- Python decorators
- Separation of concerns
- Functional programming
- Higher-order function
- Higher-order programming
Tutorials:
- Primer on Python decorators, Real Python.
- A guide to Python’s function decorators, The Code Ship, Ayman Farhat.
- Decorators, Python 3 Patterns, Recipes and Idioms.
Appendix: Variadic Decorators
The following higher-order functions apply to functions that take any number of arguments.
count = 0
def cached(fn):
    cache = {}
    def wrapper(*args):
        # The tuple of positional arguments is the cache key.
        if args in cache:
            return cache[args]
        result = fn(*args)
        cache[args] = result
        return result
    return wrapper
def counting(fn):
    def wrapper(*args):
        global count
        count += 1
        return fn(*args)
    return wrapper
def traced(fn):
    def wrapper(*args):
        print('called {}({})'.format(fn.__name__, ', '.join(map(str, args))))
        return fn(*args)
    return wrapper
They can be used on fib, which has a single parameter:
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n-2) + fib(n-1)
fib = traced(fib)
fib = counting(fib)
fib = cached(fib)
print("fib({}) = {}; fib was called {:,} times".format(10, fib(10), count))
called fib(10)
called fib(8)
called fib(6)
called fib(4)
called fib(2)
called fib(3)
called fib(1)
called fib(5)
called fib(7)
called fib(9)
fib(10) = 55; fib was called 10 times
But unlike the original higher-order functions, they can also be used with a more natural definition of pow that has two parameters:
def pow(base, exp):
    if exp == 0:
        return 1
    if exp == 1:
        return base
    half = exp // 2
    return pow(base, half) * pow(base, exp - half)
pow = traced(pow)
pow = counting(pow)
pow = cached(pow)
print('exp({}, {}) = {}'.format(2, 15, pow(2, 15)))
called pow(2, 15)
called pow(2, 7)
called pow(2, 3)
called pow(2, 1)
called pow(2, 2)
called pow(2, 4)
called pow(2, 8)
exp(2, 15) = 32768
And like the original higher-order functions, these new functionals can also be used as decorators:
@cached
@counting
@traced
def pow(base, exp):
    if exp == 0:
        return 1
    if exp == 1:
        return base
    half = exp // 2
    return pow(base, half) * pow(base, exp - half)
print('exp({}, {}) = {}'.format(2, 15, pow(2, 15)))
called pow(2, 15)
called pow(2, 7)
called pow(2, 3)
called pow(2, 1)
called pow(2, 2)
called pow(2, 4)
called pow(2, 8)
exp(2, 15) = 32768
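The wrappers above accept positional arguments only, so a call such as pow(base=2, exp=15) would fail. A possible extension, sketched here as an assumption rather than as part of the original notes, also forwards keyword arguments and folds them into the cache key. The function power and the counter calls are made up for this illustration. As with any dict-based cache, every argument must be hashable, and power(3, 3) and power(3, exp=3) are treated as different keys.

```python
def cached(fn):
    cache = {}
    def wrapper(*args, **kwargs):
        # Sort the keyword items so the key doesn't depend on the order they were passed.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = fn(*args, **kwargs)
        return cache[key]
    return wrapper

calls = 0

@cached
def power(base, exp=2):
    global calls
    calls += 1
    return base ** exp

print(power(3), power(3), power(3, exp=3), power(3, exp=3))
print("computed {} times".format(calls))
```

Each distinct call signature is computed once, so the four calls above trigger two computations.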