-- # AD with dual numbers
--
-- [AD](https://en.wikipedia.org/wiki/Automatic_differentiation) is
-- the term for a family of techniques that compute the derivatives of
-- computer programs.  One particularly simple technique is
-- [forward-mode AD with dual
-- numbers](https://blog.demofox.org/2014/12/30/dual-numbers-automatic-differentiation/),
-- which has the convenient property that it can be implemented as a
-- library in most programming languages.  This is in contrast to
-- other techniques that require nonlocal code transformations,
-- typically integrated into the compiler.  The basic idea of
-- forward-mode AD is that we write our functions such that "numbers"
-- carry not just their normal ("primal") value, but also their
-- differentiated value ("tangent").
--
-- Differentiation is defined on
-- [fields](https://en.wikipedia.org/wiki/Field_(mathematics)), so
-- first we [define a module type](abstract-data-types.html) that
-- describes the field interface.

module type field = {
  -- | The field element type.
  type t

  -- | Constructing an element from a float.
  val f64 : f64 -> t

  val + : t -> t -> t
  val - : t -> t -> t
  val * : t -> t -> t
  val / : t -> t -> t

  -- | Additive identity.
  val zero : t
  -- | Multiplicative identity.
  val one : t
  -- | Additive inverse.
  val neg : t -> t
  -- | Multiplicative inverse.
  val recip : t -> t
}

-- Plain fields are not enough to define many interesting functions.
-- An ordered field is a field that also admits the usual comparison
-- operators.

module type ordered_field = {
  include field
  val ==: t -> t -> bool
  val <: t -> t -> bool
  val >: t -> t -> bool
  val <=: t -> t -> bool
  val >=: t -> t -> bool
  val !=: t -> t -> bool
}

-- We can define modules that implement `ordered_field` from scratch,
-- but it is more convenient to define a parametric module that can
-- take any module `F` that implements the `real` module type, and
-- construct a corresponding `ordered_field`.
-- The `real` module type is implemented by the built-in `f32` and
-- `f64` modules.  Most of the definitions simply forward to those of
-- the module parameter `F`.

module mk_field_from_numeric (F: real) : ordered_field with t = F.t = {
  type t = F.t

  def f64 = F.f64
  def (+) = (F.+)
  def (-) = (F.-)
  def (*) = (F.*)
  def (/) = (F./)
  def (==) = (F.==)
  def (<) = (F.<)
  def (>) = (F.>)
  def (<=) = (F.<=)
  def (>=) = (F.>=)
  def (!=) = (F.!=)

  def zero = F.i64 0
  def one = F.i64 1
  def neg = F.neg
  def recip = F.recip
}

-- We create a field module where the elements are `f64`:

module f64_field = mk_field_from_numeric f64

-- Now we can define the module type of fields of dual numbers.  These
-- support the usual field operations, as well as extracting the
-- "primal" (normal) and "tangent" (differentiated) components of a
-- number.

module type dual_field = {
  -- | The element type of the underlying field (the components of the
  -- dual numbers).
  type underlying

  -- | We include all the operations required by ordered fields.  The
  -- 't' will be a dual number.
  include ordered_field

  -- | The primal of a dual number is the normal result.
  val primal : t -> underlying
  -- | The tangent is, well, the tangent.
  val tangent : t -> underlying

  -- | Construct a dual number with tangent zero.
  val dual0 : underlying -> t
  -- | Construct a dual number with tangent one.
  val dual1 : underlying -> t
}

-- We now define a parametric module that, given an ordered field,
-- constructs an ordered field that uses dual numbers for the field
-- elements.  We keep the actual representation of the dual numbers
-- abstract.

module mk_dual (F: ordered_field) : (dual_field with underlying = F.t) = {
  type underlying = F.t

  -- We represent a dual number as a pair of the "primal" and "tangent"
  -- parts.
  type t = (underlying, underlying)

  def primal ((x, _) : t) = x
  def tangent ((_, x') : t) = x'
  def dual0 x : t = (x, F.f64 0)
  def dual1 x : t = (x, F.f64 1)

  -- A constant has tangent zero.
  def f64 x = dual0 (F.f64 x)
  def zero = f64 0
  def one = f64 1

  -- Negation is defined in the obvious way.
  def neg (x,x') = (F.neg x, F.neg x')

  -- The reciprocal is a little more tricky, but you can look up the
  -- reciprocal rule in a calculus textbook (or more realistically, on
  -- Wikipedia).
  def recip (x,x') = (F.recip x, F.(neg (x'/(x*x))))

  -- Then we get to the actual arithmetic operations.  These are also
  -- as you'd expect to find in a textbook.  We define subtraction and
  -- division via the inverse elements, so we have fewer things
  -- written from scratch.
  def (x,x') + (y,y') = F.((x + y, x' + y'))
  def (x,x') * (y,y') = F.((x * y, x' * y + x * y'))
  def x - y = x + neg y
  def x / y = x * recip y

  -- Comparisons are straightforward and use only the primal parts.
  -- Since we produce booleans here, the result has no tangent.
  def (x,_) == (y,_) = F.(x == y)
  def (x,_) < (y,_) = F.(x < y)
  def (x,_) > (y,_) = F.(x > y)
  def (x,_) <= (y,_) = F.(x <= y)
  def (x,_) >= (y,_) = F.(x >= y)
  def (x,_) != (y,_) = F.(x != y)
}

-- We instantiate it with the `f64_field` module defined above:

module dual_f64 = mk_dual f64_field

-- To show off forward-mode AD, we define various functions
-- parameterised over the field representation.  This lets us evaluate
-- them using either normal numbers or dual numbers.  In Haskell we'd
-- use type classes for this, but in Futhark and other ML languages,
-- we use a parametric module yet again.  This carries a fair bit of
-- notational overhead, unfortunately.

module mk_test (F: ordered_field) = {
  def test (x: F.t) (y: F.t) =
    F.((x*x) + (x*y))
}

-- We can instantiate this module using both ordinary numbers and dual
-- numbers:

module test_f64 = mk_test f64_field
module test_dual = mk_test dual_f64

-- Ordinary evaluation works:
--
-- ```
-- > :t test_f64.test
-- test_f64.test : (x: f64) -> (y: f64) -> f64
-- > test_f64.test 1 2
-- 3.0f64
-- ```
--
-- When we want to compute the partial derivative of a function with
-- respect to one of its parameters, we pass it a dual number with
-- initial tangent *1* for *just that* parameter (using the `dual1`
-- function), and make the initial tangent *0* for all the other
-- parameters:
--
-- ```
-- > dual_f64.tangent (test_dual.test (dual_f64.dual1 1) (dual_f64.dual0 2))
-- 4.0f64
-- > dual_f64.tangent (test_dual.test (dual_f64.dual0 1) (dual_f64.dual1 2))
-- 1.0f64
-- ```
--
-- For a function with *n* inputs, we need to perform *n* evaluations
-- to obtain all partial derivatives.  This is the main weakness of
-- forward-mode AD.
--
-- More complex functions also work, such as this implementation of
-- square root by naive Newton-Raphson iteration (note that this is
-- not numerically stable):

module mk_sqrt (F: ordered_field) = {
  def abs (x: F.t) =
    if F.(x < f64 0) then F.neg x else x

  def sqrt (x: F.t) =
    let difference = F.f64 0.001
    in loop guess = F.f64 1
       while F.(abs(guess * guess - x) >= difference) do
         F.((x/guess + guess)/(f64 2))
}

module sqrt_dual = mk_sqrt dual_f64

-- Trying it out in the REPL:
--
-- ```
-- > dual_f64.tangent (sqrt_dual.sqrt (dual_f64.dual1 9))
-- 0.16673279002623667f64
-- ```
--
-- In practice, we usually add square roots and similar primitive
-- functions to our field interface and implement their known
-- derivatives directly.  But this shows that even something as nasty
-- as a `while`-loop works fine with forward-mode AD.
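--
-- As a minimal sketch of that last remark, and not part of the
-- modules above, the following treats square root as a primitive
-- whose derivative is known to be 1/(2*sqrt(x)).  To stay
-- self-contained it assumes dual numbers are plain `(f64, f64)`
-- pairs rather than the abstract type produced by `mk_dual`; the
-- names `dual`, `dual_sqrt` and `sqrt_tangent_at_9` are hypothetical.
--
-- ```futhark
-- -- Hypothetical stand-alone representation: (primal, tangent).
-- type dual = (f64, f64)
--
-- -- Square root as a primitive, using d/dx sqrt(x) = 1 / (2 * sqrt(x)).
-- def dual_sqrt ((x, x'): dual) : dual =
--   let r = f64.sqrt x
--   in (r, x' / (2 * r))
--
-- -- Tangent of sqrt at 9, seeded with tangent 1.
-- def sqrt_tangent_at_9 : f64 = (dual_sqrt (9, 1)).1
-- ```
--
-- The exact derivative of the square root at 9 is 1/6, about
-- 0.16667, whereas the iterative version above reports about
-- 0.16673.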