Skip to content

Random.map without stack overflow #289

@TysonMN

Description

@TysonMN

This issue is a spin-off of #238 (comment).

Recall that the type 'a -> 'b is a (covariant) functor in 'b. The map function for this functor is function composition (either >> or <<; the two differ only in the order of their inputs). Our Random<'c> type is a (covariant) functor in 'c because it is just a wrapper around the type 'a -> 'b -> 'c, which is also a (covariant) functor in 'c. We can simplify things by uncurrying to get back to a function of the form 'a -> 'b.

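The naive way to implement map for the (covariant) functor 'a -> 'b can overflow the stack: each composition wraps the previous function in a new closure, so calling the final function needs a stack frame for every composition performed. Specifically, the following test fails.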

[<Fact>]
let ``Does function composition overflow the stack? Answer: Yes`` () =
  // Left-folding `>>` over many copies of `id` builds a chain of nested
  // closures. Calling the result descends through every link before
  // returning, so with 100_000 links the call overflows the stack.
  let depth = 100_000
  let composed = List.fold (>>) id (List.replicate depth id)
  composed ()

Here is one way to avoid overflowing the stack. I will be the first to say that this is not elegant.

/// A function from 'a to 'b represented as a list of boxed steps rather than
/// as nested closures. Composing onto a Fun only conses one step onto a list,
/// so evaluating it does not need stack depth proportional to the number of
/// compositions.
///
/// Evaluation order: In, then FuncsBefore from head to tail, then FuncsAfter
/// from tail to head, then Out.
type Fun<'a, 'b> =
  {
    /// Converts the input into the boxed value that flows through the steps.
    In: 'a -> obj
    /// Steps added by composeBefore; the head runs first.
    FuncsBefore: (obj -> obj) list
    /// Steps added by composeAfter, stored newest-first; the head runs last.
    FuncsAfter: (obj -> obj) list
    /// Converts the final boxed value into the result.
    Out: obj -> 'b
  }


module Fun =

  /// The identity Fun: boxes its input and unboxes it again, with no steps.
  let id<'a> =
    { In = box<'a>
      FuncsBefore = []
      FuncsAfter = []
      Out = unbox<'a> }

  /// Applies f to a by running In, every FuncsBefore step (head first),
  /// every FuncsAfter step (oldest first), and finally Out.
  /// Steps are applied by iterative folds, so the stack depth does not depend
  /// on how many functions were composed.
  let evaluate f a =
    let run value step = step value
    let afterBefore = List.fold run (f.In a) f.FuncsBefore
    // FuncsAfter is stored newest-first; reverse it so the oldest step runs first.
    let afterAfter = List.fold run afterBefore (List.rev f.FuncsAfter)
    f.Out afterAfter

  /// Returns a Fun that applies g first and then f.
  /// g's result is passed through f.In (rather than assuming f.In is box), so
  /// any conversion f performs on its input is preserved.
  let composeBefore (g: 'a -> 'b) (f: Fun<'b, 'c>) : Fun<'a, 'c> =
    { In = box<'a>
      FuncsBefore = (unbox<'a> >> g >> f.In) :: f.FuncsBefore
      FuncsAfter = f.FuncsAfter
      Out = f.Out }

  /// Returns a Fun that applies f first and then g.
  /// f's boxed result is first converted with f.Out (rather than assuming
  /// f.Out is unbox), so any conversion f performs on its output is preserved.
  let composeAfter (g: 'b -> 'c) (f: Fun<'a, 'b>) : Fun<'a, 'c> =
    { In = f.In
      FuncsBefore = f.FuncsBefore
      FuncsAfter = (f.Out >> g >> box<'c>) :: f.FuncsAfter
      Out = unbox<'c> }


[<Fact>]
let ``Custom function composition`` () =
  let flip f b a = f a b
  let count = 1_000_000

  // Composes `count` copies of (+) 1 onto Fun.id using the given strategy
  // and returns the evaluated function.
  let composeIncrements compose =
    List.replicate count ((+) 1)
    |> List.fold (flip compose) Fun.id
    |> Fun.evaluate

  // A million compositions must evaluate without overflowing the stack,
  // whichever end of the pipeline the functions are added to.
  let totalAfter = composeIncrements Fun.composeAfter 0
  Assert.Equal(count, totalAfter)

  let totalBefore = composeIncrements Fun.composeBefore 0
  Assert.Equal(count, totalBefore)

  // Heterogeneous steps, built from the input end with composeAfter:
  // double -> string -> int -> bool.
  let evenLengthViaAfter =
    Fun.id<double>
    |> Fun.composeAfter (sprintf "%f")
    |> Fun.composeAfter String.length
    |> Fun.composeAfter (fun len -> len % 2 = 0)
    |> Fun.evaluate
  let isEvenAfter = evenLengthViaAfter 3.141592
  Assert.True(isEvenAfter)

  // The same pipeline, built from the output end with composeBefore.
  let evenLengthViaBefore =
    Fun.id<bool>
    |> Fun.composeBefore (fun len -> len % 2 = 0)
    |> Fun.composeBefore String.length
    |> Fun.composeBefore (sprintf "%f")
    |> Fun.evaluate
  let isEvenBefore = evenLengthViaBefore 3.141592
  Assert.True(isEvenBefore)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions