Dex: Multi-step ray tracer
The following is a port of raytrace.dx, which is itself based on this JAX program by Eric Jang. See this blog post for details on the algorithm; this example won’t be getting into those details. We’ve tried to maintain the naming scheme of the original program, but we’ve re-ordered most definitions to bring them closer to their first usage.
import "dex-prelude"
Geometry
The original Dex program implements vectors in three-dimensional space with three-element arrays. This is not optimal in Futhark for reasons explained elsewhere, so instead we import the vector module from another example program.
module threed = import "3d-vectors"
And instantiate it with double precision floats.
module vspace = threed.mk_vspace_3d f64
Convenient shorthands for various vector types and operations:
type Vec = vspace.vector
def (*>) = vspace.scale
def (<+>) = (vspace.+)
def (<->) = (vspace.-)
def (<*>) = (vspace.*)
def dot = vspace.dot
def length = vspace.length
def normalize = vspace.normalise
def cross = vspace.cross
def rotateY = flip vspace.rot_y
These vectors are not arrays and we cannot map them, but we can define our own mapping function.
def vmap (f: f64 -> f64) (v : Vec) =
  {x = f v.x, y = f v.y, z = f v.z}
Dex calls the scaling operator .*, but that name is not valid in Futhark. Note that our length function (the Euclidean norm of a vector) will shadow the one in the Futhark prelude, which returns the number of elements in an array.
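To make the record representation concrete, here is a minimal stand-alone sketch that does not depend on the 3d-vectors module. It assumes a record with fields x, y and z (the type name vec3 and the helpers vmap3 and vabs are hypothetical stand-ins for vspace.vector and the definitions above), and shows a vmap-style helper together with a scaling operator spelled *> because .* is unavailable.

```futhark
-- Hypothetical stand-in for vspace.vector: a record, not an array.
type vec3 = {x: f64, y: f64, z: f64}

-- Apply a scalar function to every component (the role of vmap above).
def vmap3 (f: f64 -> f64) (v: vec3) : vec3 =
  {x = f v.x, y = f v.y, z = f v.z}

-- Scaling as a user-defined operator, since `.*` is not a legal Futhark name.
def (*>) (s: f64) (v: vec3) : vec3 = vmap3 (s *) v

-- Component-wise absolute value, the kind of helper sdObject relies on.
def vabs (v: vec3) : vec3 = vmap3 f64.abs v
```

For example, 2.0 *> vabs {x = -1.0, y = 2.0, z = -3.0} evaluates to {x = 2.0, y = 4.0, z = 6.0}.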
Dex’s Color type is from the plot.dex file, which we’re not going to implement in its entirety.
Much of the following is just as straightforward in Futhark as in Dex, and so will not have much commentary.
type Color = Vec
type Angle = f64 -- angle in radians
type Distance = f64
type Position = Vec
type Direction = Vec -- Should be normalized.
type BlockHalfWidths = Vec
type Radius = f64
type Radiance = Color
type ObjectGeom
  = #Wall Direction Distance
  | #Block Position BlockHalfWidths Angle
  | #Sphere Position Radius
type Surface
  = #Matte Color
  | #Mirror
type OrientedSurface = (Direction, Surface)
type Object
  = #PassiveObject ObjectGeom Surface
  -- position, half-width, intensity (assumed to point down)
  | #Light Position f64 Radiance
def vec x y z : Vec = {x,y,z}
The signed distance function
This function is the reason why the Futhark implementation is so convoluted, as we will need its derivative.
def sdObject (pos:Position) (obj:Object) : Distance =
  match obj
  case #PassiveObject geom _ ->
    (match geom
     case #Wall nor d -> f64.(d + dot nor pos)
     case #Block blockPos halfWidths angle ->
       let pos' = rotateY (pos <-> blockPos) angle
       in length (vmap (f64.max 0) (vmap f64.abs pos' <-> halfWidths))
     case #Sphere spherePos r ->
       let pos' = pos <-> spherePos
       in f64.(max (length pos' - r) (f64 0)))
  case #Light squarePos hw _ ->
    let pos' = pos <-> squarePos
    let halfWidths = {x=hw, y=f64.f64 0.1, z=hw}
    in length (vmap (f64.max 0)
                    (vmap f64.abs pos' <-> halfWidths))
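As a worked check of the box formula used in the #Block and #Light cases, the following sketch evaluates it for an axis-aligned box centred at the origin, leaving out the translation and rotation. The helper name box_distance is hypothetical. With half-widths of 1, the point (2, 0, 0) is 1 unit from the box, (2, 2, 0) is sqrt 2 away from the nearest edge, and every interior point gives 0 because of the clamp, so this distance is only informative outside the box.

```futhark
type vec3 = {x: f64, y: f64, z: f64}

-- Exterior-only distance from p to a box centred at the origin with
-- half-widths h: length (max (|p| - h) 0), as in the #Block case.
def box_distance (p: vec3) (h: vec3) : f64 =
  let qx = f64.max 0 (f64.abs p.x - h.x)
  let qy = f64.max 0 (f64.abs p.y - h.y)
  let qz = f64.max 0 (f64.abs p.z - h.z)
  in f64.sqrt (qx*qx + qy*qy + qz*qz)
```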
Working with the geometry
Finally we can define the function that computes surface normals at
a given position. This is where we use AD. Specifically, we
define a grad
operator to convert distance functions into normal
functions.
def grad 'a (f: a -> f64) (x: a) : a = vjp f x 1
def calcNormal (obj: Object) (pos: Position) : Direction =
  normalize (grad (flip sdObject obj) pos)
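A small sketch can check this on a sphere, where the outward normal at a point p (relative to the centre) is known analytically to be p / |p|. The names sd_sphere and sphere_normal are hypothetical, and the distance is deliberately left unclamped, unlike the #Sphere case above, so the gradient is also meaningful inside the sphere (except at its centre).

```futhark
type vec3 = {x: f64, y: f64, z: f64}

-- Signed distance to a sphere of radius r centred at the origin.
def sd_sphere (r: f64) (p: vec3) : f64 =
  f64.sqrt (p.x*p.x + p.y*p.y + p.z*p.z) - r

-- Reverse-mode gradient of a scalar function, as in grad above.
def grad 'a (f: a -> f64) (x: a) : a = vjp f x 1

-- AD should reproduce the analytic normal p / |p|.
def sphere_normal (r: f64) (p: vec3) : vec3 = grad (sd_sphere r) p
```

For instance, sphere_normal 1.0 {x = 0.0, y = 3.0, z = 4.0} should be approximately {x = 0.0, y = 0.6, z = 0.8}.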
Random sampling
Since dex-prelude.fut defines the same random number generation interface as the Dex program, the sampling functions are straightforward ports.
First, a function for generating a number in a range:
def randuniform (lower: f64) (upper: f64) (k: Key) =
  let x = rand k
  in (lower + x * (upper-lower))
Sample within a square with side length 2*hw:
def sampleSquare (hw: f64) (k: Key) : Position =
  let (kx, kz) = splitKey k
  let x = randuniform (- hw) hw kx
  let z = randuniform (- hw) hw kz
  in {x, y=0.0, z}
Sample within a hemisphere in the direction of the normal vector:
def sampleCosineWeightedHemisphere (normal: Vec) (k: Key) : Vec =
  let (k1, k2) = splitKey k
  let u1 = rand k1
  let u2 = rand k2
  let uu = normalize (cross normal (vec 0.0 1.1 1.1))
  let vv = cross uu normal
  let ra = f64.sqrt u2
  let rx = ra * f64.cos (2 * f64.pi * u1)
  let ry = ra * f64.sin (2 * f64.pi * u1)
  let rz = f64.sqrt (1.0 - u2)
  let rr = (rx *> uu) <+> (ry *> vv) <+> (rz *> normal)
  in normalize rr
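Since the construction only needs two uniform numbers, it can be checked without the Key machinery. The sketch below is a hypothetical variant, hemisphere_sample, that takes u1 and u2 directly and carries its own minimal vector helpers. For a unit normal, (uu, vv, normal) is orthonormal, and rx² + ry² = u2 while rz² = 1 - u2, so the sample has unit length and its projection onto the normal is sqrt(1 - u2), which is never negative. One caveat: the basis degenerates if the normal is parallel to (0, 1.1, 1.1).

```futhark
type vec3 = {x: f64, y: f64, z: f64}

def dot (a: vec3) (b: vec3) : f64 = a.x*b.x + a.y*b.y + a.z*b.z
def scale (s: f64) (v: vec3) : vec3 = {x = s*v.x, y = s*v.y, z = s*v.z}
def add (a: vec3) (b: vec3) : vec3 = {x = a.x+b.x, y = a.y+b.y, z = a.z+b.z}
def cross (a: vec3) (b: vec3) : vec3 =
  {x = a.y*b.z - a.z*b.y, y = a.z*b.x - a.x*b.z, z = a.x*b.y - a.y*b.x}
def normalize (v: vec3) : vec3 = scale (1 / f64.sqrt (dot v v)) v

-- Same construction as sampleCosineWeightedHemisphere, but taking the two
-- uniform numbers u1, u2 in [0,1) directly instead of drawing them from a Key.
-- The final normalisation is omitted so the test can check the unit length.
def hemisphere_sample (normal: vec3) (u1: f64) (u2: f64) : vec3 =
  let uu = normalize (cross normal {x = 0.0, y = 1.1, z = 1.1})
  let vv = cross uu normal
  let ra = f64.sqrt u2
  let rx = ra * f64.cos (2 * f64.pi * u1)
  let ry = ra * f64.sin (2 * f64.pi * u1)
  let rz = f64.sqrt (1 - u2)
  in add (add (scale rx uu) (scale ry vv)) (scale rz normal)
```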
Ray marching
The essence of ray marching is simple: move along a given vector and find the first object we collide with. We do this the naive way, by invoking the distance function for every single object. This is not efficient, to put it mildly, but it is easy to implement.
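A one-dimensional sketch shows the idea. The ray here is fixed to the positive z axis and the scene is a single unit sphere centred at z = 5, so the distance is a scalar function of the distance travelled, t. The step factor 0.9, the tolerance and the starting point 10 * tol mirror the raymarch function defined further down; the names sd_along_z and march are hypothetical. Each step closes 90% of the remaining gap, so the march stops just short of t = 4 after a few iterations.

```futhark
-- Distance from the point (0, 0, t) to a unit sphere centred at (0, 0, 5).
def sd_along_z (t: f64) : f64 = f64.abs (5 - t) - 1

-- March from t = 10 * tol, stepping 0.9 of the safe distance each time.
-- Returns whether the surface was reached and the distance travelled.
def march (tol: f64) (max_iters: i32) : (bool, f64) =
  let (_, t, hit) =
    loop (i, t, hit) = (0, 10 * tol, false)
    while i < max_iters && !hit do
      let d = sd_along_z t
      in if d < tol
         then (i + 1, t, true)
         else (i + 1, t + 0.9 * d, false)
  in (hit, t)
```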
First we need a minimumBy function. Futhark lacks Dex’s support for ad-hoc polymorphism, which matters here because we end up passing in the less-than operator explicitly, but the result doesn’t look too bad.
def minBy 'a 'o (lt: o -> o -> bool) (f: a->o) (x:a) (y:a) : a =
  if f x `lt` f y then x else y
def minimumBy [n] 'a 'o (lt: o -> o -> bool) (f: a->o) (xs: [n]a) : a =
  reduce (minBy lt f) xs[0] xs
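Note that xs[0] is not a true neutral element for minBy, and reduce may combine the neutral element with other elements more than once. This is harmless because minBy is idempotent and xs[0] is itself a member of xs, so the minimum value found cannot change; it does, however, mean that minimumBy fails on an empty array. The sketch below repeats the two definitions so it stands alone and uses them the way sdScene does, minimizing over indices rather than values. closest_index is a hypothetical example function.

```futhark
def minBy 'a 'o (lt: o -> o -> bool) (f: a -> o) (x: a) (y: a) : a =
  if f x `lt` f y then x else y

-- xs[0] stands in for the neutral element, so xs must be non-empty.
def minimumBy [n] 'a 'o (lt: o -> o -> bool) (f: a -> o) (xs: [n]a) : a =
  reduce (minBy lt f) xs[0] xs

-- Hypothetical example: the index of the element nearest to target,
-- found by minimizing over indices, as sdScene does for objects.
def closest_index [n] (target: f64) (xs: [n]f64) : i64 =
  minimumBy (<) (\i -> f64.abs (xs[i] - target)) (indices xs)
```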
A scene is a collection of objects.
type Scene [n] = [n]Object
We can now define a function for finding the closest object, given a position.
def sdScene (objs: Scene []) (pos: Position) : (Object, Distance) =
  let i =
    minimumBy (<) (\i -> sdObject pos objs[i])
              (indices objs)
  in (objs[i], sdObject pos objs[i])
When we ray march, we either collide with nothing and disappear into the aether, hit a light, or hit an object, in which case we also produce information about the surface, including the surface normal.
type Ray = (Position, Direction)
type RayMarchResult
  = #HitObj Ray OrientedSurface
  | #HitLight Radiance
  | #HitNothing
The ray marching function is defined with the iter function in Dex. In Futhark, a while loop can express it in an equally natural way.
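Futhark loops have no break statement, so raymarch ends early by returning max_iters as the new counter once it hits something, which makes the loop condition fail. Here is the same trick in isolation, as a hypothetical find_first helper that returns the index of the first element satisfying a predicate, or -1 if there is none.

```futhark
-- Return the first index whose element satisfies p, or -1 if none does.
-- The loop "breaks" by jumping its counter to the bound n.
def find_first [n] (p: f64 -> bool) (xs: [n]f64) : i64 =
  let (_, res) =
    loop (i, res) = (0, -1)
    while i < n do
      if p xs[i]
      then (n, i)
      else (i + 1, res)
  in res
```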
def positiveProjection (x: Vec) (y: Vec) =
  dot x y > 0
def raymarch [n] (scene:Scene [n]) (ray:Ray) : RayMarchResult =
  let max_iters = 100
  let tol = 0.01
  let (rayOrigin, rayDir) = ray
  let (_, _, res) =
    loop (i, rayLength, _) = (0, 10 * tol, #HitNothing)
    while i < max_iters do
      let rayPos = rayOrigin <+> (rayLength *> rayDir)
      let (obj, d) = sdScene scene rayPos
      let dNew = rayLength + 0.9 * d
      in if d >= tol
         then (i+1, dNew, #HitNothing)
         else
           let surfNorm = calcNormal obj rayPos
           in if positiveProjection rayDir surfNorm
              then (i+1, dNew, #HitNothing)
              else (max_iters,
                    dNew,
                    match obj
                    case #PassiveObject _ surf ->
                      #HitObj (rayPos, rayDir) (surfNorm, surf)
                    case #Light _ _ radiance ->
                      #HitLight radiance)
  in res
Light sampling
These definitions are pretty straightforward, and similar to those in Dex.
To figure out whether a light is shining on us along a given path, we just march along that path and see if we hit a light. If we do, the radiance is that of the light.
def rayDirectRadiance [n] (scene: Scene [n]) (ray: Ray) : Radiance =
  match raymarch scene ray
  case #HitLight intensity -> intensity
  case #HitNothing -> vec 0 0 0
  case #HitObj _ _ -> vec 0 0 0
Shorthands for vectors along the cardinal axes will become useful in the following.
def xHat : Vec = vec 1 0 0
def yHat : Vec = vec 0 1 0
def zHat : Vec = vec 0 0 1
def relu (x: f64) = f64.max x 0
def probReflection ((nor, surf): OrientedSurface) (_:Ray) ((_, outRayDir):Ray) : f64 =
  match surf
  case #Matte _ -> relu (dot nor outRayDir)
  case #Mirror -> 0
def directionAndLength (x: Vec) =
  (normalize x, length x)
The following function determines how much light is shining on a given point. It’s very similar to the Dex function. Dex uses an explicit accumulator that is updated via effects, but I don’t think that would make it much more readable in this case.
def sampleLightRadiance [n] (scene: Scene [n])
                            (osurf: OrientedSurface)
                            (inRay: Ray)
                            (k: Key) : Radiance =
  let (surfNor, _) = osurf
  let (rayPos, _) = inRay
  in loop radiance = vec 0 0 0 for obj in scene do
       match obj
       case #PassiveObject _ _ -> radiance
       case #Light lightPos hw _ ->
         let (dirToLight, distToLight) =
           directionAndLength (lightPos <+> sampleSquare hw k <-> rayPos)
         in if ! (positiveProjection dirToLight surfNor)
            then radiance -- light on the far side of current surface
            else
              let fracSolidAngle = relu (dot dirToLight yHat) *
                                   sq hw / (f64.pi * sq distToLight)
              let outRay = (rayPos, dirToLight)
              let coeff = fracSolidAngle * probReflection osurf inRay outRay
              in radiance <+> (coeff *> rayDirectRadiance scene outRay)
Tracing
Almost done. Everything here is very similar to the Dex code.
type Filter = Color
def applyFilter (filter:Filter) (radiance:Radiance) : Radiance =
  filter <*> radiance
def surfaceFilter (filter:Filter) (surf:Surface) : Filter =
  match surf
  case #Matte color -> filter <*> color
  case #Mirror -> filter
def sampleReflection ((nor, surf): OrientedSurface) ((pos, dir): Ray) (k: Key) : Ray =
  let newDir = match surf
               case #Matte _ -> sampleCosineWeightedHemisphere nor k
               case #Mirror -> dir <-> 2 * dot dir nor *> nor
  in (pos, newDir)
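The #Mirror branch is the standard reflection formula d - 2 (d · n) n for a unit normal n. A minimal stand-alone check, using a hypothetical reflect function on plain records: a ray travelling down and to the right, (1, -1, 0), bounces off a floor with normal (0, 1, 0) and leaves as (1, 1, 0).

```futhark
type vec3 = {x: f64, y: f64, z: f64}

def dot (a: vec3) (b: vec3) : f64 = a.x*b.x + a.y*b.y + a.z*b.z

-- Mirror reflection of direction d about unit normal n: d - 2 (d . n) n,
-- the same expression as the #Mirror case of sampleReflection.
def reflect (d: vec3) (n: vec3) : vec3 =
  let k = 2 * dot d n
  in {x = d.x - k*n.x, y = d.y - k*n.y, z = d.z - k*n.z}
```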
We’re excluding Dex’s shareSeed field from the tracing parameters, since it seems mostly relevant for showing the impact of poor seeding on the convergence of the algorithm.
type Params = { numSamples : i32,
                maxBounces : i32 }
def trace [n] (params: Params) (scene: Scene [n]) (init_ray: Ray) (k: Key) : Color =
  (.2) <|
  loop (i, filter, radiance, ray) = (0, (vec 1 1 1), (vec 0 0 0), init_ray)
  while i < params.maxBounces do
    match raymarch scene ray
    case #HitNothing ->
      (params.maxBounces, filter, radiance, ray)
    case #HitLight intensity ->
      if i == 0
      then (params.maxBounces, filter, intensity, ray)
      else (params.maxBounces, filter, radiance, ray)
    case #HitObj incidentRay osurf ->
      let (k1, k2) = splitKey (hash k i)
      let lightRadiance = sampleLightRadiance scene osurf incidentRay k1
      let outRayHemisphere = sampleReflection osurf incidentRay k2
      let newFilter = surfaceFilter filter osurf.1
      let newRadiance = radiance <+> applyFilter newFilter lightRadiance
      in (i+1, newFilter, newRadiance, outRayHemisphere)
Camera
The camera controls how the initial rays are sent into the scene.
type Camera =
  { numPix : i64,
    pos : Position,
    halfWidth : f64,
    sensorDist : f64 }
And now we’re ready to produce the n*n array of initial rays and RNG states.
def cameraRays (n: i64) (camera: Camera) : [n][n](Ray, Key) =
  let halfWidth = camera.halfWidth
  let pixHalfWidth = halfWidth / f64.i64 n
  let ys = reverse (linspace n (-halfWidth) halfWidth)
  let xs = linspace n (-halfWidth) halfWidth
  let kss = tabulate_2d n n (\i j -> newKey (i32.i64 (1+i*n+j)))
  let rayForPixel y x k =
    -- Split the key so the x and y jitter are independent; drawing both
    -- from k itself would make dx and dy identical.
    let (kx, ky) = splitKey k
    let dx = randuniform (-pixHalfWidth) pixHalfWidth kx
    let dy = randuniform (-pixHalfWidth) pixHalfWidth ky
    in ((camera.pos,
         normalize (vec (x + dx) (y + dy) (-(camera.sensorDist)))),
        k)
  in map2 (\y ks -> map2 (rayForPixel y) xs ks) ys kss
Most ray tracers perform multiple samples per pixel and take the average. The Dex program defines the sampleAveraged function to be (potentially) parallel. We know the parallelism would never be exploited anyway, so we define it as a sequential loop, mostly to make the random number state management simpler.
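For comparison, here is a sketch of both shapes on a plain function of the sample index. This is a simplification: the document's version passes each sample a Key derived with ixkey rather than the index itself. average_seq and average_par are hypothetical names; for a deterministic sample function they compute the same mean, up to floating-point summation order.

```futhark
-- Sequential average, in the style of sampleAveraged.
def average_seq (sample: i64 -> f64) (n: i64) : f64 =
  (loop acc = 0.0 for i < n do acc + sample i) / f64.i64 n

-- The parallel alternative: map over the sample indices, then sum.
def average_par (sample: i64 -> f64) (n: i64) : f64 =
  f64.sum (map sample (iota n)) / f64.i64 n
```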
def sampleAveraged (sample: Key -> Vec) (n: i32) (k: Key) : Vec =
  (loop acc = vec 0 0 0 for i < n do
     (acc <+> sample (ixkey k (i64.i32 i))))
  |> ((1/f64.i32 n) *>)
The Dex implementation uses an unusual relative colorisation strategy, where the final color values are all divided by the average intensity. This is likely so we won’t have to fiddle with the light intensities to avoid very bright or very dark images.
def meanIntensity image =
  image |> flatten |> map (\{x,y,z} -> (x+y+z)/3) |> mean
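One consequence of this relative colorisation is that the final image always has a mean intensity of 1, whatever the light strengths (provided the mean is positive). A stand-alone sketch of the same computation on plain records, with f64.sum and a division standing in for mean; mean_intensity and normalise_image are hypothetical names.

```futhark
type vec3 = {x: f64, y: f64, z: f64}

-- Average of (x + y + z) / 3 over all pixels, as in meanIntensity.
def mean_intensity [h][w] (image: [h][w]vec3) : f64 =
  let intensities = map (\(c: vec3) -> (c.x + c.y + c.z) / 3) (flatten image)
  in f64.sum intensities / f64.i64 (h * w)

-- Divide every channel by the mean intensity, as takePicture does.
def normalise_image [h][w] (image: [h][w]vec3) : [h][w]vec3 =
  let m = mean_intensity image
  in map (map (\(c: vec3) -> {x = c.x / m, y = c.y / m, z = c.z / m})) image
```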
Smile for the camera!
def takePicture [m] (params: Params) (scene: Scene [m]) (camera: Camera) =
  let n = camera.numPix
  let rays = cameraRays n camera
  let sample (r, k) =
    sampleAveraged (trace params scene r) params.numSamples k
  let image = map (map sample) rays
  let mean = meanIntensity image
  in map (map ((1/mean)*>)) image
Scene definition
The only odd thing here is that our vectors apparently rotate in the opposite direction to Dex's, so I had to give the block a rotation of -0.5 instead of 0.5.
def lightColor = vec 0.2 0.2 0.2
def leftWallColor = 1.5 *> vec 0.611 0.0555 0.062
def rightWallColor = 1.5 *> vec 0.117 0.4125 0.115
def whiteWallColor = (1/255) *> vec 255.0 239.0 196.0
def blockColor = (1/255) *> vec 200.0 200.0 255.0
def neg = ((-1)*>)
def theScene : Scene [] =
  [ #Light (1.9 *> yHat) 0.5 lightColor
  , #PassiveObject (#Wall xHat 2.0) (#Matte leftWallColor)
  , #PassiveObject (#Wall (neg xHat) 2.0) (#Matte rightWallColor)
  , #PassiveObject (#Wall yHat 2.0) (#Matte whiteWallColor)
  , #PassiveObject (#Wall (neg yHat) 2.0) (#Matte whiteWallColor)
  , #PassiveObject (#Wall zHat 2.0) (#Matte whiteWallColor)
  , #PassiveObject (#Block (vec 1.0 (-1.6) 1.2) (vec 0.6 0.8 0.6) (-0.5)) (#Matte blockColor)
  , #PassiveObject (#Sphere (vec (-1.0) (-1.2) 0.2) 0.8) (#Matte (0.7 *> whiteWallColor))
  , #PassiveObject (#Sphere (vec 2 2 (-2)) 1.5) #Mirror
  ]
def defaultParams : Params = { numSamples = 50
                             , maxBounces = 10 }
def defaultCamera : Camera = { numPix = 250
                             , pos = 10.0 *> zHat
                             , halfWidth = 0.3
                             , sensorDist = 1.0 }
Entry point
This is not part of the Dex program, but is needed if we want Futhark to produce anything for the outside world.
First, a function to convert floating-point colors to RGB-packed colors with 8 bits per channel. Note that we have to cap each channel, since nothing guarantees that the color components produced by the ray tracer stay at or below 1.0.
def pix (c: Color) =
  (u32.f64 (f64.min 255 (c.x * 255)) << 16) |
  (u32.f64 (f64.min 255 (c.y * 255)) << 8) |
  (u32.f64 (f64.min 255 (c.z * 255)))
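Packing places red in bits 16 to 23, green in bits 8 to 15 and blue in bits 0 to 7. A small sketch with a hypothetical unpack inverse makes the layout and the clamping visible: a blue component of 2.0 is capped at 255 instead of becoming 510, which would spill into the green channel. The vec3 record stands in for Color.

```futhark
type vec3 = {x: f64, y: f64, z: f64}

-- Same packing as pix above.
def pix (c: vec3) : u32 =
  (u32.f64 (f64.min 255 (c.x * 255)) << 16) |
  (u32.f64 (f64.min 255 (c.y * 255)) << 8) |
  (u32.f64 (f64.min 255 (c.z * 255)))

-- Recover the three 8-bit channels from a packed pixel.
def unpack (p: u32) : (u32, u32, u32) =
  ((p >> 16) & 0xFF, (p >> 8) & 0xFF, p & 0xFF)
```

For example, pix {x = 1.0, y = 0.0, z = 2.0} is 0xFF00FF.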
def main n =
  takePicture defaultParams theScene (defaultCamera with numPix = n)
  |> map (map pix)
> :img main 500i64