Source file: dual-numbers.fut

AD with dual numbers

Automatic differentiation (AD) is the term for a family of techniques that compute the derivatives of computer programs. One particularly simple technique is forward-mode AD with dual numbers, which has the convenient property that it can be implemented as a library in most programming languages. This is in contrast to other techniques that require nonlocal code transformations, typically integrated into the compiler. The basic idea of forward-mode AD is that we write our functions such that “numbers” carry not just their normal (“primal”) value, but also their differentiated value (“tangent”).
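One standard way to motivate this, which the text above does not spell out, is to read a dual number as $x + x'\varepsilon$, where $\varepsilon$ is a formal symbol assumed to satisfy $\varepsilon^2 = 0$. Under that assumption, a first-order Taylor expansion gives:

$$
f(x + x'\varepsilon) = f(x) + f'(x)\,x'\varepsilon, \qquad \varepsilon^2 = 0
$$

For example, with $f(x) = x^2$ and tangent $x' = 1$ we get $(x + \varepsilon)^2 = x^2 + 2x\varepsilon$: the primal part is $f(x)$ and the tangent part is $f'(x) = 2x$.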

Differentiation is defined on fields, so first we define a module type that describes the field interface.

module type field = {
  -- | The field element type.
  type t

  -- | Constructing an element from a float.
  val f64 : f64 -> t

  val + : t -> t -> t
  val - : t -> t -> t
  val * : t -> t -> t
  val / : t -> t -> t

  -- | Additive identity.
  val zero : t

  -- | Multiplicative identity.
  val one : t

  -- | Additive inverse.
  val neg : t -> t

  -- | Multiplicative inverse.
  val recip : t -> t
}

Plain fields are not enough to define many interesting functions. An ordered field is a field that also admits the usual comparison operators.

module type ordered_field = {
  include field

  val ==: t -> t -> bool
  val <: t -> t -> bool
  val >: t -> t -> bool
  val <=: t -> t -> bool
  val >=: t -> t -> bool
  val !=: t -> t -> bool
}

We can define modules that implement ordered_field from scratch, but it is more convenient to define a parametric module that can take any module F that implements the real module type, and construct a corresponding ordered_field. The real module type is implemented by the built-in f32 and f64 modules. Most of the definitions simply forward to those of the module parameter F.

module mk_field_from_numeric (F: real) : ordered_field with t = F.t = {
  type t = F.t
  def f64 = F.f64
  def (+) = (F.+)
  def (-) = (F.-)
  def (*) = (F.*)
  def (/) = (F./)
  def (==) = (F.==)
  def (<) = (F.<)
  def (>) = (F.>)
  def (<=) = (F.<=)
  def (>=) = (F.>=)
  def (!=) = (F.!=)

  def zero = F.i64 0
  def one = F.i64 1
  def neg = F.neg
  def recip = F.recip
}

We create a field module where the elements are f64:

module f64_field = mk_field_from_numeric f64

Now we can define the module type of fields of dual numbers. These support the usual field operations, as well as extracting the “primal” (normal) and “tangent” (differentiated) components of a number.

module type dual_field = {
  -- | The element type of the underlying field (the components of the
  -- dual numbers).
  type underlying

  -- | We include all the operations required by ordered fields.  The
  -- 't' will be a dual number.
  include ordered_field

  -- | The primal of a dual number is the normal result.
  val primal : t -> underlying

  -- | The tangent is, well, the tangent.
  val tangent : t -> underlying

  -- | Construct a dual number with tangent zero.
  val dual0 : underlying -> t

  -- | Construct a dual number with tangent one.
  val dual1 : underlying -> t
}

We now define a parametric module that, given an ordered field, constructs an ordered field that uses dual numbers for the field elements. We keep the actual representation of the dual numbers abstract.

module mk_dual (F: ordered_field) : (dual_field with underlying = F.t) = {
  type underlying = F.t

We represent a dual number as a pair of the “primal” and “tangent” parts.

  type t = (underlying, underlying)
  def primal ((x, _) : t) = x
  def tangent ((_, x') : t) = x'
  def dual0 x : t = (x, F.f64 0)
  def dual1 x : t = (x, F.f64 1)

A constant has tangent zero.

  def f64 x = dual0 (F.f64 x)
  def zero = f64 0
  def one = f64 1

Negation is defined in the obvious way.

  def neg (x,x') = (F.neg x, F.neg x')

The reciprocal is a little more tricky, but you can look up the reciprocal rule in a calculus textbook (or more realistically, on Wikipedia).

  def recip (x,x') = (F.recip x, F.(neg (x'/(x*x))))
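As a sanity check on the rule used above, read the pair (x, x') as $x + x'\varepsilon$ with $\varepsilon^2 = 0$ (an interpretation assumed here, not part of the code). Multiplying a dual number by its claimed reciprocal should then give exactly one:

$$
(x + x'\varepsilon)\left(\frac{1}{x} - \frac{x'}{x^2}\varepsilon\right)
= 1 + \left(\frac{x'}{x} - \frac{x'}{x}\right)\varepsilon
= 1
$$

The tangent component $-x'/x^2$ is the same as the expression `neg (x'/(x*x))` in the definition.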

Then we get to the actual arithmetic operations. These are also as you’d expect to find in a textbook. We define subtraction and division via the inverse elements, so we have fewer things written from scratch.

  def (x,x') + (y,y') = F.((x + y, x' + y'))
  def (x,x') * (y,y') = F.((x * y, x' * y + x * y'))
  def x - y = x + neg y
  def x / y = x * recip y
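To see that these definitions agree with the textbook rules, read each pair (x, x') as $x + x'\varepsilon$ with $\varepsilon^2 = 0$ (an interpretation assumed here for illustration). Multiplication reproduces the product rule:

$$
(x + x'\varepsilon)(y + y'\varepsilon) = xy + (x'y + xy')\varepsilon
$$

Defining division as multiplication by the reciprocal, whose tangent is $-y'/y^2$, gives the quotient rule:

$$
\frac{x + x'\varepsilon}{y + y'\varepsilon}
= (x + x'\varepsilon)\left(\frac{1}{y} - \frac{y'}{y^2}\varepsilon\right)
= \frac{x}{y} + \frac{x'y - xy'}{y^2}\varepsilon
$$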

Comparisons are straightforward and use only the primal parts. Since we produce booleans here, the result has no tangent.

  def (x,_) == (y,_) = F.(x == y)
  def (x,_) < (y,_) = F.(x < y)
  def (x,_) > (y,_) = F.(x > y)
  def (x,_) <= (y,_) = F.(x <= y)
  def (x,_) >= (y,_) = F.(x >= y)
  def (x,_) != (y,_) = F.(x != y)
}

We instantiate it with the f64_field module defined above:

module dual_f64 = mk_dual f64_field

To show off forward-mode AD, we define various functions parameterised over the field representation. This lets us evaluate them using either normal numbers or dual numbers. In Haskell we’d use type classes for this, but in Futhark and other ML languages, we use a parametric module yet again. This carries a fair bit of syntactic overhead, unfortunately.

module mk_test (F: ordered_field) = {
  def test (x: F.t) (y: F.t) =
    F.((x*x) + (x*y))
}

We can instantiate this module using both ordinary numbers and dual numbers:

module test_f64 = mk_test f64_field
module test_dual = mk_test dual_f64

Ordinary evaluation works:

> :t test_f64.test
test_f64.test : (x: f64) -> (y: f64) -> f64
> test_f64.test 1 2
3.0f64

When we want to compute the partial derivative of a function with respect to one of its parameters, we pass it a dual number with initial tangent 1 for just that parameter (using the dual1 function), and make the initial tangent 0 for all the other parameters:

> dual_f64.tangent (test_dual.test (dual_f64.dual1 1) (dual_f64.dual0 2))
4.0f64
> dual_f64.tangent (test_dual.test (dual_f64.dual0 1) (dual_f64.dual1 2))
1.0f64

For a function with n inputs, we need to perform n evaluations to obtain all partial derivatives. This is the main weakness of forward-mode AD.
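The cost of one pass per input can be made concrete with a small standalone sketch. It does not use the modules defined above. Instead it assumes a simplified representation of dual numbers as plain pairs of f64, and the names dual, dadd, dmul, sum_squares, gradient and grad_sum_squares are all hypothetical, introduced only for this illustration.

```futhark
-- Standalone sketch: dual numbers as plain (primal, tangent) pairs of f64.
-- All names here are hypothetical and independent of the modules above.
type dual = (f64, f64)

def dadd ((x, x'): dual) ((y, y'): dual) : dual = (x + y, x' + y')
def dmul ((x, x'): dual) ((y, y'): dual) : dual = (x * y, x' * y + x * y')

-- Example function with n inputs: the sum of squares.
def sum_squares [n] (xs: [n]dual) : dual =
  reduce dadd (0.0, 0.0) (map (\x -> dmul x x) xs)

-- One forward pass per input: pass i seeds tangent 1 for input i and
-- tangent 0 for all other inputs, then reads off the result's tangent.
def gradient [n] (f: [n]dual -> dual) (xs: [n]f64) : [n]f64 =
  tabulate n (\i ->
    let seeded = map2 (\j x -> (x, if i == j then 1.0 else 0.0)) (iota n) xs
    let (_, d) = f seeded
    in d)

-- Entry point so the sketch can be run directly.
entry grad_sum_squares [n] (xs: [n]f64) : [n]f64 =
  gradient sum_squares xs
```

Each of the n elements of the result requires a full evaluation of f, so the total work is n times the cost of the function itself.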

More complex functions work as well, such as this implementation of square root by naive Newton-Raphson iteration (note that this is not numerically stable):

module mk_sqrt (F: ordered_field) = {

  def abs (x: F.t) =
    if F.(x < f64 0)
    then F.neg x
    else x

  def sqrt (x: F.t) =
    let difference = F.f64 0.001
    in loop guess = F.f64 1
       while F.(abs(guess * guess - x) >= difference) do
         F.((x/guess + guess)/(f64 2))
}

module sqrt_dual = mk_sqrt dual_f64

Trying it out in the REPL:

> dual_f64.tangent (sqrt_dual.sqrt (dual_f64.dual1 9))
0.16673279002623667f64
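As a quick check of this result, the exact derivative of the square root at 9 is:

$$
\left.\frac{d}{dx}\sqrt{x}\,\right|_{x=9} = \frac{1}{2\sqrt{9}} = \frac{1}{6} \approx 0.16667
$$

The computed tangent differs from this in the fourth decimal place, which is consistent with the loop stopping as soon as the squared guess is within 0.001 of x, rather than iterating to full precision.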

In practice, we usually add square roots and similar primitive functions to our field interface and directly implement their known derivatives. But this shows that even something as nasty as a while-loop works fine with forward-mode AD.
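A minimal sketch of what such a primitive might look like, again assuming the simplified pair representation rather than the abstract dual_field module. The names dual, dsqrt and sqrt_tangent are hypothetical; f64.sqrt is the built-in square root.

```futhark
-- Standalone sketch: dual numbers as plain (primal, tangent) pairs of f64.
-- The names dual, dsqrt and sqrt_tangent are hypothetical.
type dual = (f64, f64)

-- Square root as a primitive. The derivative of sqrt(x) is
-- 1 / (2 * sqrt(x)), scaled by the incoming tangent x' (chain rule).
def dsqrt ((x, x'): dual) : dual =
  let r = f64.sqrt x
  in (r, x' / (2.0 * r))

-- Tangent of sqrt at x, using seed tangent 1.
entry sqrt_tangent (x: f64) : f64 =
  let (_, d) = dsqrt (x, 1.0)
  in d
```

Unlike the Newton-Raphson version, the tangent here is as accurate as the built-in square root itself and takes a single step to compute.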