In this post we will see how to extend reverse mode automatic differentiation to a language with first class function types, function application and λ-abstraction. This method is not new, but we will give a new derivation of it by showing how it arises universally from noticing that the category of “additive lenses” is cartesian closed. In the end we will see that this idea sounds like it should revolutionise machine learning, but then doesn’t. Everything in this post is joint work with Bruno Gavranović, who is now my colleague at Glaive.
Additive containers and additive lenses
We need to begin with some background categorical cybernetics, from the paper Categorical foundations of gradient-based learning and Bruno’s PhD thesis.
The categorical semantics of types in a differentiable language are additive containers, which is what we call set-indexed families of commutative monoids. (We use an unnecessarily fancy word because set-indexing is stronger than we sometimes want or need; categorical cybernetics has lots of machinery for working with related objects, so we use terminology that unifies them.) The idea is that the semantics of a type is (1) a set $X$ of values that the type can have in the forwards pass, and (2) for each of those values $x : X$, a commutative monoid $X'(x)$ of derivatives it can have at that value in the backwards pass. We write such an additive container as $\binom{X}{X'}$.
The semantics of a function is an additive lens $f : \binom{X}{X'} \to \binom{Y}{Y'}$ between the additive containers $\binom{X}{X'}$ and $\binom{Y}{Y'}$. This consists of two things: (1) an ordinary function $f : X \to Y$ between sets describing the semantics of the forwards pass, and (2) an $X$-indexed family of monoid homomorphisms $f'_x : Y'(f(x)) \to X'(x)$ from the commutative monoid $Y'(f(x))$ to $X'(x)$. That is to say, these are dependent lenses whose backwards passes are additive in their second input. The chain rule for lens composition exactly describes the denotational semantics of how the forwards and backwards passes interact in autodiff.
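To make the lens data concrete, here is a minimal Python sketch. It assumes for simplicity that every monoid of derivatives is the reals under addition, so the dependence on the point $x$ only shows up as the backwards pass taking `x` as an argument. The class name `Lens` and the method `then` are illustrative and do not come from any library. Composition recomputes the intermediate forwards value and applies the chain rule $(g \circ f)'_x = f'_x \circ g'_{f(x)}$.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    """An additive lens, assuming every derivative monoid is (float, +).

    fwd(x)     -- the forwards pass, an ordinary function
    bwd(x, dy) -- the backwards pass at the point x, sending a gradient at
                  fwd(x) to a gradient at x; assumed (not checked) additive in dy
    """
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]

    def then(self, other: "Lens") -> "Lens":
        """The composite 'self, then other', with the chain rule as its backwards pass."""
        def fwd(x):
            return other.fwd(self.fwd(x))

        def bwd(x, dz):
            y = self.fwd(x)                       # recompute the intermediate value
            return self.bwd(x, other.bwd(y, dz))  # pull dz back through other, then self

        return Lens(fwd, bwd)


square = Lens(lambda x: x * x, lambda x, dy: 2 * x * dy)
sine = Lens(math.sin, lambda x, dy: math.cos(x) * dy)

h = square.then(sine)  # h(x) = sin(x ** 2), so h'(x) = cos(x ** 2) * 2x
```

Here `h.bwd(1.5, 1.0)` evaluates to $3\cos(2.25)$, which is the derivative of $\sin(x^2)$ at $x = 1.5$.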
In the category of ordinary containers and lenses, without the monoid structure, we have both a cartesian product and a tensor product. The cartesian product of containers (ie. the one that satisfies the universal property of a product) is given by product in the forwards pass and coproduct in the backwards pass: $\binom{X}{X'} \times \binom{Y}{Y'} = \binom{X \times Y}{X' + Y'}$. (To be more specific, the backwards part is the $X \times Y$-indexed family of sets $(x, y) \mapsto X'(x) + Y'(y)$.) This happens essentially because the backwards pass is backwards, so the universal property of a product turns into a coproduct. There is also a tensor product given by product in the forwards and backwards passes: $\binom{X}{X'} \otimes \binom{Y}{Y'} = \binom{X \times Y}{X' \times Y'}$. For the majority of applications in categorical cybernetics the tensor product is the useful one, but it doesn’t satisfy a straightforward universal property.[1]
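The following Python sketch puts the two products on ordinary containers side by side. The first is the pairing map $\langle f, g \rangle$, which witnesses the universal property of the cartesian product. The second is the tensor product of two lenses. This is illustrative only. A coproduct gradient is modelled as a tagged tuple, `("left", d)` or `("right", d)`, and the names `pair` and `tensor` are hypothetical. In the pairing, exactly one side's gradient comes back, so nothing is ever added.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    """An ordinary (non-additive) lens: fwd(a), and bwd(a, d) returning a gradient at a."""
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


def pair(f: Lens, g: Lens) -> Lens:
    """Pairing <f, g> : A -> B x C for the cartesian product of ordinary containers.

    The backwards part of the product is a coproduct, so an incoming gradient is
    tagged as coming from the left or the right, and we dispatch on the tag.
    """
    def bwd(a, tagged):
        tag, d = tagged
        return f.bwd(a, d) if tag == "left" else g.bwd(a, d)

    return Lens(lambda a: (f.fwd(a), g.fwd(a)), bwd)


def tensor(f: Lens, g: Lens) -> Lens:
    """Tensor f (x) g : A (x) C -> B (x) D, with products in both passes."""
    return Lens(lambda ac: (f.fwd(ac[0]), g.fwd(ac[1])),
                lambda ac, dbd: (f.bwd(ac[0], dbd[0]), g.bwd(ac[1], dbd[1])))


double = Lens(lambda x: 2 * x, lambda x, dy: 2 * dy)
negate = Lens(lambda x: -x, lambda x, dy: -dy)

both = pair(double, negate)   # a |-> (2a, -a), gradients come back from one side
par = tensor(double, negate)  # (a, c) |-> (2a, -c), gradients come back from both
```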
But in additive containers something different happens: the thing that satisfies the universal property of a cartesian product switches to $\binom{X \times Y}{X' \times Y'}$. The structural reason for this is that the category of commutative monoids has finite biproducts: the monoid of pairs $M \times N$ satisfies the universal properties of both a product and a coproduct. It’s worth thinking for a moment about why this is: given maps $f : M \to P$ and $g : N \to P$, we get a copairing $[f, g] : M \times N \to P$ given by $[f, g](m, n) = f(m) + g(n)$; and the injections are given by setting one of the values to be zero, $m \mapsto (m, 0)$ and $n \mapsto (0, n)$.
This gives the correct semantics to copying in autodiff: if we use a forwards pass value multiple times, we get a different derivative coming from each use site and the right thing to do is to add those derivatives.
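The copairing formula and its consequence for copying can be checked in a few lines of Python. The sketch assumes all monoids are `float` under addition. In it, `copair` implements $[f, g](m, n) = f(m) + g(n)$. The backwards pass of the copying lens $x \mapsto (x, x)$ is exactly the copairing of two identities, so the gradients arriving from the two use sites are added. All names are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]

    def then(self, other: "Lens") -> "Lens":
        """Composite 'self, then other' via the chain rule."""
        return Lens(lambda x: other.fwd(self.fwd(x)),
                    lambda x, dz: self.bwd(x, other.bwd(self.fwd(x), dz)))


def copair(f, g):
    """Copairing [f, g] : M x N -> P of monoid maps: evaluate both and add in P."""
    return lambda mn: f(mn[0]) + g(mn[1])


def inl(m):
    """Injection M -> M x N: pad with zero."""
    return (m, 0.0)


def identity(v):
    return v


# Copying: the forwards pass duplicates x, and the backwards pass adds the two
# incoming gradients, which is precisely the copairing [id, id].
copy_lens = Lens(lambda x: (x, x), lambda x, d: copair(identity, identity)(d))

# Multiplication of a pair, with the product rule as its backwards pass.
mul = Lens(lambda p: p[0] * p[1], lambda p, dz: (p[1] * dz, p[0] * dz))

square = copy_lens.then(mul)  # x |-> x * x, built by using x twice
```

At $x = 3$ each use site of $x$ receives a gradient of $3$, and the copy adds them to give the expected derivative $6$.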
Additive containers are cartesian closed
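The category of ordinary containers and ordinary lenses is monoidal closed for the tensor product.[2] Given containers $\binom{X}{X'}$ and $\binom{Y}{Y'}$, the hom-container has as its forwards set the set of all lenses $\binom{X}{X'} \to \binom{Y}{Y'}$, and as the backwards set indexed by a specific lens $f$ it has the set $\sum_{x : X} Y'(f(x))$, whose elements are pairs $(x, y')$ where $x : X$ and $y' : Y'(f(x))$. Writing this whole container as $\binom{X}{X'} \multimap \binom{Y}{Y'}$, this gives us the universal property that lenses $\binom{Z}{Z'} \otimes \binom{X}{X'} \to \binom{Y}{Y'}$ are in natural bijection with lenses $\binom{Z}{Z'} \to \left(\binom{X}{X'} \multimap \binom{Y}{Y'}\right)$.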
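The currying bijection for ordinary lenses can be sketched in Python. The sketch assumes gradients are plain values and that the backwards pass of a lens out of a tensor product returns a pair. Under those assumptions, `curry` turns a lens $\binom{Z}{Z'} \otimes \binom{X}{X'} \to \binom{Y}{Y'}$ into one whose forwards pass returns a lens. Its backwards pass receives a single pair $(x, y')$ from the hom-container, and `uncurry` goes back. The function names are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    """An ordinary (non-additive) lens: fwd(x), and bwd(x, dy) returning a gradient at x."""
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


def curry(f: Lens) -> Lens:
    """Turn f : Z (x) X -> Y into a lens Z -> (X -o Y).

    f.bwd((z, x), dy) is assumed to return the pair (dz, dx), since the
    backwards part of a tensor product is a product.
    """
    def fwd(z):
        # The forwards value at z is a whole lens X -> Y.
        return Lens(lambda x: f.fwd((z, x)),
                    lambda x, dy: f.bwd((z, x), dy)[1])

    def bwd(z, pair):
        # A gradient for the hom-container at that lens is a single pair (x, dy).
        x, dy = pair
        return f.bwd((z, x), dy)[0]

    return Lens(fwd, bwd)


def uncurry(g: Lens) -> Lens:
    """The inverse direction: from Z -> (X -o Y) back to Z (x) X -> Y."""
    def fwd(zx):
        z, x = zx
        return g.fwd(z).fwd(x)

    def bwd(zx, dy):
        z, x = zx
        return (g.bwd(z, (x, dy)), g.fwd(z).bwd(x, dy))

    return Lens(fwd, bwd)


mul = Lens(lambda zx: zx[0] * zx[1], lambda zx, dy: (zx[1] * dy, zx[0] * dy))
roundtrip = uncurry(curry(mul))  # should behave exactly like mul
```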
Meanwhile in the world of additive containers, the thing that was previously just the tensor product is now the cartesian product, so we might wonder whether additive containers are cartesian closed. And they are! The forwards set of $\binom{X}{X'} \multimap \binom{Y}{Y'}$ is, as expected, the set of additive lenses of that type. (I am going to continue using the symbol $\multimap$ for the internal hom of additive containers, to carefully distinguish it from functions in the underlying category of sets.) But there is a very interesting twist in the backwards part.
The backwards monoid is still the thing that we would write $\sum_{x : X} Y'(f(x))$. It is the coproduct, in the category of commutative monoids, of the $X$-indexed family of monoids $Y'(f(x))$. So we need to know what coproducts of commutative monoids are like. This category has finite biproducts, but for infinite families products and coproducts no longer coincide. The product remains the obvious thing, the monoid $\prod_{x : X} Y'(f(x))$ whose elements are dependent functions; this is called a direct product. The coproduct is a slightly different construction called a direct sum: an element of $\bigoplus_{x : X} Y'(f(x))$ is a dependent function $g : (x : X) \to Y'(f(x))$ satisfying the property that all except finitely many values $g(x)$ are zero. That is, it is a finite support dependent function. (The monoid structure is pointwise addition of functions.) The underlying set of $\bigoplus_{x : X} Y'(f(x))$ is something I would write as $(x : X) \to_{\mathrm{fin}} Y'(f(x))$.
It’s worth also spending a moment thinking about why the coproduct has to be something like this. Suppose we have an infinite family of monoids $M_i$, and for each of them we have a map out $f_i : M_i \to N$. We need to be able to take their “cotupling” $[f_i]_i : \bigoplus_i M_i \to N$, which we know from the binary case should be to evaluate each $f_i$ and then add the results in $N$. But $N$ doesn’t admit any kind of infinitary addition in general, so we can only do this if we know that only finitely many of the $m_i$ are nonzero.
To summarise: the category of additive containers and additive lenses is cartesian closed, where the hom-container $\binom{X}{X'} \multimap \binom{Y}{Y'}$ has as its forwards set the set of additive lenses, and where the backwards commutative monoid over an additive lens $f$ is the coproduct of commutative monoids $\bigoplus_{x : X} Y'(f(x))$, whose underlying set is the set of finite support functions $(x : X) \to Y'(f(x))$.
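Here is a small Python sketch of the direct sum and its cotupling. It takes every $M_i$ to be `float` under addition and represents a finite support element as a dictionary, with absent indices implicitly zero. The family $i \mapsto f_i$ is an arbitrary Python function, so it can be infinite, but cotupling only ever evaluates the finitely many stored indices. The names `FinSupp`, `fs_add` and `cotuple` are hypothetical.

```python
from collections.abc import Callable, Hashable

# An element of a direct sum  ⊕_i M_i, taking every M_i to be (float, +):
# a dict from indices to components. Absent indices are implicitly zero,
# so a (finite) dict always describes a finite support function.
FinSupp = dict[Hashable, float]


def fs_add(g: FinSupp, h: FinSupp) -> FinSupp:
    """Pointwise addition; the result is supported within the union of supports."""
    out = dict(g)
    for i, m in h.items():
        out[i] = out.get(i, 0.0) + m
    return out


def cotuple(f: Callable[[Hashable], Callable[[float], float]]) -> Callable[[FinSupp], float]:
    """Cotupling [f_i]_i : ⊕_i M_i -> N for a possibly infinite family i |-> f_i.

    Only the finitely many indices stored in the dict contribute, so this is an
    ordinary finite sum in N. Each f_i is assumed to be a monoid homomorphism.
    """
    return lambda g: sum((f(i)(m) for i, m in g.items()), 0.0)


# A family indexed by all natural numbers: f_i(m) = i * m.
weighted = cotuple(lambda i: (lambda m: i * m))
```

For example, `weighted({2: 1.0, 5: 3.0})` is $2 \cdot 1 + 5 \cdot 3 = 17$, and cotupling sends `fs_add` to addition, as a homomorphism should.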
Representing the direct sum
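How do we represent the set of finite support functions in an implementation? The simplest idea would be to just use a function with an informal datatype invariant. But it turns out that we need the copairing operation to be implementable, and this is not possible with that representation: given an opaque function, there is no way to find out which finitely many inputs have nonzero values, so we cannot compute the finite sum that copairing requires. The thing we are actually going to use is a finite list of pairs of inputs and outputs: we represent the set as $\mathsf{List}\left(\sum_{x : X} Y'(f(x))\right)$.
Of course lists of pairs are not actually the same as finite support functions. For one thing the order doesn’t matter, so we need to quotient out list permutation. Our next thought might be that we should have a datatype invariant that each input value can only appear once in the list, so it is never ambiguous what the output should be. But there turns out to be a better way: we say that the value of the represented function $g$ at $x$ is the sum of all of the $y'$ such that $(x, y')$ appears in the list. (And if $x$ never appears then the sum is $0$.) This amounts to defining a quotient type generated by 3 classes of identifications: (1) any permutation of a list is identified with the original list; (2) $(x, y'_1) :: (x, y'_2) :: \ell = (x, y'_1 + y'_2) :: \ell$, so two entries at the same input can be merged by adding their gradients; and (3) $(x, 0) :: \ell = \ell$, so zero entries can be dropped.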
If we happened to be working in a language with higher inductive types we could tell the machine this, but as it transpires that would be overkill: it turns out that the only operation we will ever need to define out of a direct sum is to add together all of its entries, and that plainly respects all of these equations. (In the case of the first one, that’s exactly where we use the fact that all of our monoids are commutative.)
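A small Python sketch of the list-of-pairs representation follows, assuming gradients are `float`s and inputs are hashable. `evaluate` reads off the represented function by summing, and `normalise` computes a canonical dictionary, so two lists related by the three identifications normalise to the same thing. `total` is the one operation we need out of a direct sum: it adds up all entries after mapping each one with a function assumed to be additive in the gradient. It gives the same answer on identified lists. All names are hypothetical.

```python
from collections import defaultdict


def evaluate(pairs, x):
    """The represented function at x: the sum of every gradient paired with x (0 if none)."""
    return sum((dy for x0, dy in pairs if x0 == x), 0.0)


def normalise(pairs):
    """Canonical form: merge entries with the same input and drop zero entries.

    Lists related by permutation, merging or zero-dropping give equal results.
    """
    acc = defaultdict(float)
    for x, dy in pairs:
        acc[x] += dy
    return {x: dy for x, dy in acc.items() if dy != 0.0}


def total(pairs, h):
    """Add up h(x, dy) over all entries; h is assumed additive in dy.

    This respects all three identifications: permutation because addition is
    commutative, merging because h is additive, zero-dropping because h(x, 0) = 0.
    """
    return sum((h(x, dy) for x, dy in pairs), 0.0)
```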
I wrote earlier that the commutative monoid structure on finite support functions is pointwise, implicitly relying on the fact that pointwise addition of finite support functions is finite support, and that the unit of pointwise addition is the constant zero function, which is also finite support. The reason that we want a representation where $g(x)$ is the sum of all $y'$ with $(x, y')$ in the list, rather than the arguably more obvious approach of requiring that each $x$ appears in the list at most once, is that it allows us to do a very slick magic trick: pointwise addition of functions is represented by list concatenation, and the zero function by the empty list. Normally list concatenation gives a noncommutative monoid, but it is commutative modulo our first quotient equation. This is a slight subtlety: our data representation uses monoids that are technically not commutative, but everything is “hereditarily commutative” in the appropriate sense that whenever we evaluate down to actual numbers everything will work out. There is a sense in which all of our additions are “suspended”, and we will see later that they are “resumed” when the backwards pass gets to the point where we created the function, namely λ-abstraction.
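In a Python sketch of the same representation, the monoid operations are list concatenation and the empty list. No arithmetic happens until the function is evaluated. The sketch checks that concatenation agrees pointwise with addition, and that it is commutative only up to reordering, which is the first identification. The names `fs_add` and `fs_zero` are hypothetical.

```python
def evaluate(pairs, x):
    """The represented function at x: the sum of every gradient paired with x."""
    return sum((dy for x0, dy in pairs if x0 == x), 0.0)


def fs_add(g, h):
    """Pointwise addition of finite support functions, as list concatenation.

    Nothing is actually added here; the additions are suspended until evaluation.
    """
    return g + h


fs_zero = []  # the constant zero function is the empty list

g = [("a", 1.0)]
h = [("a", 2.0), ("b", 5.0)]
```

Here `fs_add(g, h)` and `fs_add(h, g)` are different lists, but they contain the same entries, so they represent the same function.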
The chain of 3 steps, from function types to direct sums to lists of pairs, combines ideas that I have not seen together in one place. But every 2 out of 3 appear somewhere, and the connection between all 3 is no doubt folklore known to some, although we haven’t been able to find actual evidence of that. So this post might be the first universal derivation of the list-of-pairs method, in the sense that everything follows from structural principles and essentially no choices are made.
The connection between function types and lists of pairs seems to be credited to the paper Lambda the ultimate backpropagator, but I find that paper quite impenetrable, so I don’t know exactly how much of it appears there. It also appears clearly in, for example, The differentiable curry. By far the closest I have seen is the paper CHAD for expressive total languages, which derives the cartesian closed structure of additive lenses as far as the infinitary direct sum. (Literally the same category: they refer to the category of additive lenses by its less catchy standard name, .) But then they never take the final step to lists of pairs, or, more importantly, recognise that they may have rederived a known method but with a universal derivation. However, the paper has accompanying Haskell code, and the list of pairs construction does appear there if you dig deeply enough.
Abstraction and application
Let’s go through the semantics of λ-abstraction and function application, which introduce and eliminate function types respectively. To do this for real we need the multicategorical version which tracks how a term in a programming language is a function of all of the variables it refers to; but to keep things simple let’s just do the version for categories, which is essentially the same.
Function application is an additive lens of type $\left(\binom{X}{X'} \multimap \binom{Y}{Y'}\right) \times \binom{X}{X'} \to \binom{Y}{Y'}$. Its forwards pass takes an additive lens $f$ and a value $x : X$ to the value $f(x)$. Its backwards pass takes an additive lens $f$, a value $x : X$ and a gradient $y' : Y'(f(x))$, and we must produce a gradient in $X'(x)$ and a function gradient in $\bigoplus_{x : X} Y'(f(x))$. Of course the gradient in $X'(x)$ is $f'_x(y')$.
The interesting part is the function type gradient, which is a list of pairs $\mathsf{List}\left(\sum_{x : X} Y'(f(x))\right)$. The thing of that type that we can make is the singleton list $[(x, y')]$. That list represents the function that takes $x$ to $y'$ and every other value to $0$. This might sound useless, but usually the reason we would have a function is because we want to call it many times, and if we do that then the semantics of copying the function value will, in the backwards pass, concatenate all of these singletons together. This will end up tracking the total gradient from every call to each possible input of the function; since a function can only be called finitely many times in any particular run of a program, this will always be a finite amount of data.
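Here is a Python sketch of the application lens, assuming function values are represented as `Lens` objects and all derivative monoids are `float` under addition. Its backwards pass returns the singleton list $[(x, y')]$ as the function gradient, together with the input gradient $f'_x(y')$. The helper `two_calls_bwd` is hypothetical. It shows how calling one function at two inputs gives the concatenation of two singletons.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


def app_fwd(fx):
    """Forwards pass of application: (f, x) |-> f(x)."""
    f, x = fx
    return f.fwd(x)


def app_bwd(fx, dy):
    """Backwards pass of application.

    Returns (function gradient, input gradient). The function gradient is the
    singleton list [(x, dy)]: 'gradient dy at input x, zero everywhere else'.
    """
    f, x = fx
    return [(x, dy)], f.bwd(x, dy)


app = Lens(app_fwd, app_bwd)

cube = Lens(lambda x: x ** 3, lambda x, dy: 3 * x ** 2 * dy)


def two_calls_bwd(f, a, b, dout):
    """Backwards pass of f(a) + f(b): each use of f contributes a singleton list,
    and copying f adds its gradients, which means concatenating the lists."""
    df_a, da = app.bwd((f, a), dout)
    df_b, db = app.bwd((f, b), dout)
    return df_a + df_b, da, db
```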
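That covers function application. Function types are made using λ-abstraction, an operation that takes an additive lens $f : \binom{X}{X'} \times \binom{Y}{Y'} \to \binom{Z}{Z'}$ between additive containers and turns it into an additive lens $\lambda f : \binom{X}{X'} \to \left(\binom{Y}{Y'} \multimap \binom{Z}{Z'}\right)$, as though we are abstracting a variable of type $\binom{Y}{Y'}$. So the data we have to work with is a forwards pass $f : X \times Y \to Z$ and a backwards pass $f'_{x,y} : Z'(f(x, y)) \to X'(x) \times Y'(y)$, which we will split into $f'^X_{x,y} : Z'(f(x, y)) \to X'(x)$ and $f'^Y_{x,y} : Z'(f(x, y)) \to Y'(y)$.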
The forwards pass of $\lambda f$ inputs a value $x : X$ and must output an additive lens $\binom{Y}{Y'} \to \binom{Z}{Z'}$. The forwards pass of that is obviously given by $y \mapsto f(x, y)$, and its backwards pass is $f'^Y_{x,y}$. (The fact that our forwards pass has a backwards pass is, of course, because we’re doing something higher-order.)
The backwards pass of $\lambda f$ takes an input $x : X$ and a list of pairs $[(y_1, z'_1), \ldots, (y_n, z'_n)]$, where each $z'_i : Z'(f(x, y_i))$, and it must produce a gradient in $X'(x)$. What we need to do is to take all of the values $f'^X_{x,y_i}(z'_i)$, which all live in the commutative monoid $X'(x)$, and add them together: $\sum_{i=1}^{n} f'^X_{x,y_i}(z'_i)$. This is an example of cotupling for a direct sum.
In ordinary first-order autodiff, copying values leads to adding gradients. There is a sense in which function types cause this reverse copying to be suspended, carrying around a list of all gradients instead of adding them. And then λ-abstraction is where the suspended operation is resumed and the values are actually added.
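Here is a Python sketch of λ-abstraction, assuming the backwards pass of `f` returns the pair $(x', y')$ and that all derivative monoids are `float` under addition. The forwards pass returns a lens $y \mapsto f(x, y)$ whose backwards pass is the $Y$-component. The backwards pass maps each list entry to $X'(x)$ with the $X$-component and adds them all, which is the cotupling. The name `abstract` is hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


def abstract(f: Lens) -> Lens:
    """λ-abstraction: from f : X x Y -> Z to a lens X -> (Y -o Z).

    f.bwd((x, y), dz) is assumed to return the pair (dx, dy).
    """
    def fwd(x):
        # The value at x is itself a lens Y -> Z: y |-> f(x, y), whose backwards
        # pass is the Y-component of f's backwards pass.
        return Lens(lambda y: f.fwd((x, y)),
                    lambda y, dz: f.bwd((x, y), dz)[1])

    def bwd(x, function_grad):
        # function_grad is a list of (y_i, dz_i). Map each entry into X'(x) using
        # the X-component of f's backwards pass, then add them all up.
        return sum((f.bwd((x, y), dz)[0] for y, dz in function_grad), 0.0)

    return Lens(fwd, bwd)


mul = Lens(lambda xy: xy[0] * xy[1], lambda xy, dz: (xy[1] * dz, xy[0] * dz))
lam = abstract(mul)  # x |-> (y |-> x * y)
```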
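The suspension and resumption can be traced end to end in a small hypothetical Python example: $h(x) = g(2) + g(3)$ where $g = \lambda y.\, x \cdot y$, so $h(x) = 5x$. The backwards pass is written out by hand rather than by a general autodiff engine. Each call of $g$ contributes a singleton list, copying $g$ concatenates them, and only the backwards pass of the λ-abstraction performs the addition that yields $5$.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


def abstract(f):
    """λ-abstraction X x Y -> Z  to  X -> (Y -o Z); its backwards pass sums the list."""
    return Lens(
        lambda x: Lens(lambda y: f.fwd((x, y)), lambda y, dz: f.bwd((x, y), dz)[1]),
        lambda x, pairs: sum((f.bwd((x, y), dz)[0] for y, dz in pairs), 0.0),
    )


def app_bwd(g, y, dz):
    """Backwards pass of application: singleton function gradient, plus input gradient."""
    return [(y, dz)], g.bwd(y, dz)


mul = Lens(lambda xy: xy[0] * xy[1], lambda xy, dz: (xy[1] * dz, xy[0] * dz))
lam = abstract(mul)


def grad_h(x):
    """Value, suspended function gradient and gradient of h(x) = g(2) + g(3), g = λy. x * y."""
    g = lam.fwd(x)                  # forwards: build the function value
    out = g.fwd(2.0) + g.fwd(3.0)   # forwards: call it twice
    dout = 1.0
    # Backwards through '+': each summand receives dout.
    df1, _ = app_bwd(g, 2.0, dout)
    df2, _ = app_bwd(g, 3.0, dout)
    # g was used twice, so its gradients are "added" by concatenation (suspended).
    dg = df1 + df2
    # Backwards through the λ-abstraction: the suspended sum is resumed.
    return out, dg, lam.bwd(x, dg)
```

For instance, `grad_h(4.0)` returns the value $20$, the suspended list `[(2.0, 1.0), (3.0, 1.0)]`, and the gradient $5$.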
There is no free lunch
Several years ago, I was talking with Mario Alvarez-Picallo about the question of differentiating function types, and he told me a piece of wisdom. He explained that if differentiating function types worked in the best possible way, then for any supervised learning problem to learn a function $X \to Y$, we could choose our parameter space to be the hom-type $\binom{X}{X'} \multimap \binom{Y}{Y'}$ and our architecture to be the application lens $\left(\binom{X}{X'} \multimap \binom{Y}{Y'}\right) \times \binom{X}{X'} \to \binom{Y}{Y'}$. That is, the application lens is a universal architecture and the hom-type is a universal parameter space, and in this way the entire science and art of designing deep learning architectures would be obsolete. This would not be so much a free lunch as winning free lunches every day for the rest of our lives. But of course, this is not what happens.
The catch is that we have only talked about backprop but not about gradient descent. A deep learning architecture consists of a lens of the shape $\binom{P}{P'} \times \binom{X}{X'} \to \binom{Y}{Y'}$, where $\binom{P}{P'}$ is the parameter space. A simplified view of training (ignoring batching and various other complications) works like this: we take an example pair $(x, y)$ from our dataset, and the current parameters $p : P$. We have a loss function that takes the predicted output $\hat{y} = f(p, x)$ and the target output $y$ and gives us a gradient $\hat{y}' : Y'(\hat{y})$, which we then inject into the backwards pass, which gives us an input gradient $x' : X'(x)$ and a parameter gradient $p' : P'(p)$. The input gradient is thrown away (its work is done by the time it reaches us), but the parameter gradient is sent to the optimiser. The most straightforward optimiser is gradient descent: it mutates the current parameters to $p - \eta p'$, where $\eta$ is a learning rate.[3]
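Here is a minimal Python sketch of this training loop. It assumes a hypothetical one-parameter architecture $\hat{y} = p \cdot x$ and the squared loss $\frac{1}{2}(\hat{y} - y)^2$, whose gradient with respect to $\hat{y}$ is $\hat{y} - y$. The backwards pass returns the parameter gradient and the input gradient. The input gradient is discarded, and plain gradient descent updates $p$.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Lens:
    fwd: Callable[[Any], Any]
    bwd: Callable[[Any, Any], Any]


# Hypothetical one-parameter architecture P x X -> Y with y_hat = p * x.
# Its backwards pass returns (parameter gradient, input gradient).
model = Lens(lambda px: px[0] * px[1],
             lambda px, dy: (px[1] * dy, px[0] * dy))


def loss_grad(y_hat, y):
    """Gradient of the squared loss (y_hat - y)**2 / 2 with respect to y_hat."""
    return y_hat - y


def train_step(p, x, y, lr):
    y_hat = model.fwd((p, x))
    dp, _dx = model.bwd((p, x), loss_grad(y_hat, y))  # the input gradient is discarded
    return p - lr * dp                                 # gradient descent: p - lr * p'


data = [(1.0, 3.0), (2.0, 6.0), (-1.0, -3.0)]  # samples of y = 3x
p = 0.0
for _epoch in range(200):
    for x, y in data:
        p = train_step(p, x, y, lr=0.05)
```

After training, `p` is very close to $3$, the slope of the data.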
This makes sense when our parameter space is some $\mathbb{R}^n$, whose tangent space is also $\mathbb{R}^n$, so it forms the additive container $\binom{\mathbb{R}^n}{\mathbb{R}^n}$. (Technically, the first $\mathbb{R}^n$ is the set of points, and the second is the constant function that takes every point in the set to the commutative monoid of vectors with addition.) What if our parameter space is instead a space of functions $\binom{X}{X'} \multimap \binom{Y}{Y'}$, whose tangent space at $f$ is the set of finite support functions from points $x : X$ to gradients in $Y'(f(x))$?
It is technically possible to do gradient descent on these things directly: if the current value of our parameters is always a finite support function, then we can add the gradients pointwise. This more or less amounts to an architecture that by design will memorise its training data; it can only overfit.
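Here is a Python sketch of doing gradient descent directly on a function-valued parameter. It assumes the function is stored as a finite table, with missing inputs mapped to $0$, and uses the squared loss. Each step receives the singleton function gradient from the application lens and updates the table pointwise. The names are hypothetical. The result reproduces the training targets exactly and knows nothing about any other input.

```python
# Parameters are themselves a function X -> Y, stored as a finite table:
# inputs not in the table are mapped to 0 (the zero function has empty support).
def predict(table, x):
    return table.get(x, 0.0)


def train_step(table, x, y, lr):
    """One gradient descent step directly on a finite support function parameter."""
    y_hat = predict(table, x)
    # Function gradient from the application lens, using the squared loss:
    # a singleton list of (input, output gradient) pairs.
    function_grad = [(x, y_hat - y)]
    new = dict(table)
    for x0, dy in function_grad:
        new[x0] = new.get(x0, 0.0) - lr * dy  # pointwise update; support stays finite
    return new


data = [(1.0, 3.0), (2.0, 6.0)]
table = {}
for _epoch in range(100):
    for x, y in data:
        table = train_step(table, x, y, lr=0.5)
```

After training, `predict(table, 1.0)` is essentially $3$, but `predict(table, 1.5)` is still $0$, because $1.5$ never appeared in the data.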
There is something much smarter we can do, which is to replace our optimiser with an entire other supervised learning architecture. We parameterise the function space using a choice of architecture $\binom{Q}{Q'} \times \binom{X}{X'} \to \binom{Y}{Y'}$ with parameter space $\binom{Q}{Q'}$, and when we receive a function gradient (that is, a list of pairs $(x, y')$) we do our own run of supervised learning on it to optimise our own parameters $q : Q$. And so our free lunch has evaporated in front of our eyes, and we are back to doing ordinary deep learning with extra steps.
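Here is a Python sketch of this "optimiser as another learner". It assumes a hypothetical inner architecture $x \mapsto q \cdot x$ and the squared loss. Each entry $(x, y')$ of the function gradient is pushed through the parameter component of the inner backwards pass, the results are added, and one gradient descent step is taken on $q$. This is exactly backpropagation through the inner architecture, which is the point about extra steps.

```python
# Hypothetical inner architecture Q x X -> Y parameterising the function
# space: the function with parameter q is x |-> q * x.
def inner_fwd(q, x):
    return q * x


def inner_param_bwd(q, x, dy):
    """Parameter component of the inner architecture's backwards pass."""
    return x * dy


def optimiser_step(q, function_grad, lr):
    """Treat a function gradient (a list of (x, dy) pairs) as a small training
    signal: backpropagate each pair through the inner architecture, add the
    parameter gradients, and take one gradient descent step on q."""
    dq = sum((inner_param_bwd(q, x, dy) for x, dy in function_grad), 0.0)
    return q - lr * dq


data = [(1.0, 3.0), (2.0, 6.0)]  # samples of y = 3x
q = 0.0
for _step in range(200):
    # Outer backwards pass (squared loss): one (x, dy) pair per call of the function.
    function_grad = [(x, inner_fwd(q, x) - y) for x, y in data]
    q = optimiser_step(q, function_grad, lr=0.1)
```

Unlike the lookup table, this learner gives about $4.5$ at the unseen input $1.5$. That generalisation comes entirely from the choice of inner architecture.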
I find it plausible that this idea could still have useful applications, in the same way that function types themselves are still useful even though we could have just worked in a Turing-complete language without functions instead. For example, it would allow us to make use of all of the basic functional programming design patterns like mapping and filtering when implementing our architectures, letting the underlying autodiff framework figure out the right thing to do instead of having to unroll everything manually.
How practical this actually is right now depends heavily on how easy it would be to hack this method on top of an existing autodiff framework, which is something we haven’t thought much about yet. My best guess is that making it performant would really require a whole new compute kernel that is able to handle reducing (ie. recursively folding over) lists of suspended gradients, and that this is possible but best left to somebody who knows what they are doing with kernel programming and takes on the task of developing a serious 21st-century autodiff framework that isn’t balanced on top of ancient Fortran code and a programming paradigm that is 50 years out of date.
The thing that currently excites me the most about this is that when we have function types we can simulate continuation passing style using a continuation monad, and so a language with function types can be used as an intermediate compiler target for a language with complex control flow features like throw/catch. What would autodiff through throw/catch look like? I have no idea, but it could be interesting to find out.
[1] That is not to say the tensor product can’t be given any universal property at all. My favourite is to characterise it as the fibrewise opposite of a cartesian product.
[2] It happens to also be cartesian closed, in an extremely non-obvious way; see the paper Higher order containers. That structure plays no role for us.
[3] There is some interesting differential geometry going on here that I only partially understand.