Dmitry Leskov
 

The Finite Laziness of Scala Streams

The official Haskell wiki discusses literally dozens of ways to produce a list of prime numbers, but the one that caught my attention the other day is not listed there. I found that neat piece of code in the Epilogue section of the paper “The Genuine Sieve of Eratosthenes” by Melissa E. O’Neill, who attributed it to Richard Bird:

primes = 2:([3..] `minus` composites)
  where
    composites = union [multiples p | p <- primes]

multiples n = map(n*)[n..]

(x:xs) `minus` (y:ys) | x < y  = x:(xs `minus` (y:ys))
                      | x == y = xs `minus` ys
                      | x > y  = (x:xs) `minus` ys

union = foldr merge []
  where
   merge (x:xs) ys = x:merge' xs ys
   merge'(x:xs) (y:ys) |x <  y = x:merge' xs (y:ys)
                       |x == y = x:merge' xs ys
                       |x >  y = y:merge' (x:xs) ys

What fascinated me most in the above code is that the union function is essentially a fold over a recursively defined infinite list of infinite lists. I wondered whether the same concept could be expressed in Scala using streams. It turned out that it could, but it took me a while to figure out how.

Here is the naive direct rewrite. It fails with a stack overflow. Can you spot the problem?

def primes = 2 #:: minus(Stream.from(3), composites)

def composites: Stream[Int] =
  union(for (p <- primes) yield multiples(p))

def multiples(n: Int) = Stream.from(n) map (n * _)

val minus: (Stream[Int], Stream[Int]) => Stream[Int] = {
  case (x#::xs, y#::ys) =>
    if      (x <  y)   x #:: minus(xs, y#::ys)
    else if (x == y)   minus(xs, ys)
    else  /* x >  y */ minus(x#::xs, ys)
}

def union(ss: Stream[Stream[Int]]): Stream[Int] =
  ss.foldRight(Stream[Int]())(merge)

val merge: (Stream[Int], Stream[Int]) => Stream[Int] = {
  case (x#::xs, ys) => x #:: merge1(xs, ys)
}

val merge1: (Stream[Int], Stream[Int]) => Stream[Int] = {
  case (x#::xs, y#::ys) =>
    if      (x <  y)   x #:: merge1(xs, y#::ys)
    else if (x == y)   x #:: merge1(xs, ys)
    else  /* x >  y */ y #:: merge1(x#::xs, ys)
}

println((primes take 20).toList)

(I opted to define some of the functions using val purely for conciseness: pattern-matching anonymous functions take less space.)

There are actually two problems:

First, functions in Haskell are non-strict by default. “Non-strict” and “lazy” are not quite the same thing, but in Scala terms this would roughly mean that all function parameters are by-name unless specified otherwise, which, as you know, is not the case. This becomes critical for the op parameter of the Stream.foldRight method:

def foldRight[B](z: B)(op: (A, B) => B): B
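
To see what a by-value op means in practice, here is a small sketch on a finite stream (the opCalls counter exists only for this illustration). The function never looks at its accumulator, yet foldRight still calls it once per element, because the fold of the tail has to be computed before op can be applied. On an infinite stream, such as the one passed to union, that computation cannot complete.

```scala
// Counts how many times the folding function is invoked.
var opCalls = 0

// The function returns its first argument and never looks at acc.
val result = Stream(1, 2, 3).foldRight(0) { (x, acc) =>
  opCalls += 1
  x
}

// acc is passed by value, so the whole fold is computed anyway:
// op runs three times even though only the outermost call matters.
println((result, opCalls)) // (1,3)
```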

As you can see, the second parameter of the merge function, once matched, appears in its body only to the right of the #:: operator:

val merge: (Stream[Int], Stream[Int]) => Stream[Int] = {
  case (x#::xs, ys) => x #:: merge1(xs, ys)
}

and the right operand of #:: (the tail) is passed by name, i.e. it is evaluated only when it is actually used. That is essentially what makes Scala's Stream class lazy. Stream.foldRight, however, passes the result of folding the tail to op by value, so that result has to be computed before op is even called. Now, if you trace the execution of primes, you will notice that the head of the stream returned by a merge call is required to evaluate the second argument of that very call, so passing the latter by value results in merge being called recursively with the same arguments, which inevitably leads to a stack overflow.
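
The by-name right operand of #:: is also what allows a stream to be defined in terms of itself, as primes is. A minimal sketch (naturals is just an illustrative name):

```scala
// The tail expression mentions naturals itself. This works because
// #:: receives its right operand by name and evaluates it only when
// the tail is first accessed; Stream then memoizes the result.
lazy val naturals: Stream[Int] = 0 #:: naturals.map(_ + 1)

println(naturals.take(5).toList) // List(0, 1, 2, 3, 4)
```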

The solution is to write a lazier version of foldRight:

def union(ss: Stream[Stream[Int]]): Stream[Int] =
  lazyFoldRight(Stream[Int]())(merge)(ss)

def lazyFoldRight[A, B](z: B)(op: (A, => B) => B)(xs: Stream[A]): B =
  if (xs.isEmpty) z
  else op(xs.head, lazyFoldRight(z)(op)(xs.tail))

val merge: (Stream[Int], => Stream[Int]) => Stream[Int] = {
  case (x#::xs, ys) => x #:: merge1(xs, ys)
}
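
With the folded tail passed by name, op decides whether and when to evaluate it, so a fold can stop early or produce its result incrementally even over an infinite stream. A short sketch follows; it repeats the lazyFoldRight definition above so that it is self-contained, and found and evens are illustrative names. The type arguments are written out so that the parameter types of the lambdas are known when they are type-checked.

```scala
// Same definition as in the text: op receives the folded tail by name.
def lazyFoldRight[A, B](z: B)(op: (A, => B) => B)(xs: Stream[A]): B =
  if (xs.isEmpty) z
  else op(xs.head, lazyFoldRight(z)(op)(xs.tail))

// || evaluates `rest` only when x > 10 is false,
// so the fold stops at the first element greater than 10.
val found: Boolean =
  lazyFoldRight[Int, Boolean](false)((x, rest) => x > 10 || rest)(Stream.from(1))

// Here `rest` sits to the right of #::, so it is not evaluated
// until the tail of the resulting stream is accessed.
val evens: Stream[Int] =
  lazyFoldRight[Int, Stream[Int]](Stream.empty)((x, rest) => (2 * x) #:: rest)(Stream.from(1))

println(found)                // true
println(evens.take(5).toList) // List(2, 4, 6, 8, 10)
```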

When it comes to the second problem, I would have started by stating that pattern matching in Haskell is non-strict (just like everything else in Haskell), were it not for the fact that Haskell also has, I kid you not, lazy pattern matching. So pattern matching is actually strict by default, but only to the extent required to accept or reject the match. For example, for an expression to match the x:xs pattern, it must be of a list type and evaluate to a cons cell; there is no need to evaluate the parts of that cons cell at the time of the match. Unless the strictness analyzer determines that x and/or xs are always required by the enclosing expression, they are extracted from the cons cell and passed around as thunks, i.e. values yet to be evaluated. The equivalent Scala constructs would be the parameterless closures () => s.head and () => s.tail.

In Scala, however, pattern matching is fully strict. If s is a stream, s match { case x#::xs => ... } is translated into a call to #::.unapply(), which evaluates both s.head and s.tail. But if you look at, e.g., the merge1 function in the naive rewrite above, you will again notice that the pattern variables xs and ys appear only in expressions to the right of the #:: operator:

val merge1: (Stream[Int], Stream[Int]) => Stream[Int] = {
  case (x#::xs, y#::ys) =>
    if      (x <  y)   x #:: merge1(xs, y#::ys)
    else if (x == y)   x #:: merge1(xs, ys)
    else  /* x >  y */ y #:: merge1(x#::xs, ys)
}
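
One way to observe this is to count how often a stream's tail is evaluated. In the sketch below (countedTail and tailEvaluations are names made up for the illustration), reading the head and wrapping the tail in a closure, as in () => s.tail, leave the tail untouched, whereas matching against a #:: pattern forces it even though the pattern discards the tail:

```scala
// Counts how many times the tail expression is evaluated.
var tailEvaluations = 0

def countedTail: Stream[Int] = {
  tailEvaluations += 1
  Stream(2, 3)
}

// Building the cons cell does not evaluate countedTail.
val s: Stream[Int] = 1 #:: countedTail

// Neither reading the head nor wrapping the tail in a closure forces it.
val headValue = s.head
val tailThunk: () => Stream[Int] = () => s.tail
val beforeMatch = tailEvaluations // 0

// The #:: extractor computes s.tail even though the pattern discards it.
val matched = s match {
  case x #:: _ => x
  case _       => -1
}
val afterMatch = tailEvaluations // 1

// The tail is memoized, so calling the closure now does not re-evaluate it.
val viaThunk = tailThunk().toList // List(2, 3)
```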

Simulating lazy matching with closures fixes the second problem, but makes the code less readable:

def merge1(s: Stream[Int], t: Stream[Int]): Stream[Int] = {
  val (x, y) = (s.head, t.head)
  val (xs, ys) = (() => s.tail, () => t.tail)
  if      (x <  y)   x #:: merge1(xs(), y #:: ys())
  else if (x == y)   x #:: merge1(xs(), ys())
  else  /* x >  y */ y #:: merge1(x #:: xs(), ys())
}

so in the end I’ve dropped pattern matching altogether:

def merge1(xs: Stream[Int], ys: Stream[Int]): Stream[Int] =
  if      (xs.head <  ys.head)   xs.head #:: merge1(xs.tail, ys)
  else if (xs.head == ys.head)   xs.head #:: merge1(xs.tail, ys.tail)
  else  /* xs.head >  ys.head */ ys.head #:: merge1(xs, ys.tail)
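
For reference, the pieces can be assembled into one self-contained script. This is a reconstruction from the snippets above, not necessarily the code in the GitHub repository. Two choices go beyond what the text shows: minus is rewritten without pattern matching, because its (x#::xs, y#::ys) pattern would likewise force the tail of the composites stream; and merge is written as an ordinary lambda over its two parameters, so the by-name ys is referenced only inside the tail of the #:: expression.

```scala
// Variant of foldRight whose op receives the folded tail by name.
def lazyFoldRight[A, B](z: B)(op: (A, => B) => B)(xs: Stream[A]): B =
  if (xs.isEmpty) z
  else op(xs.head, lazyFoldRight(z)(op)(xs.tail))

// n*n, n*(n+1), n*(n+2), ...
def multiples(n: Int): Stream[Int] = Stream.from(n) map (n * _)

// Elements of xs that are not in ys; both streams are infinite and increasing.
// Only heads are compared, so the tail of ys is not forced while xs.head < ys.head.
def minus(xs: Stream[Int], ys: Stream[Int]): Stream[Int] =
  if      (xs.head <  ys.head) xs.head #:: minus(xs.tail, ys)
  else if (xs.head == ys.head) minus(xs.tail, ys.tail)
  else                         minus(xs, ys.tail)

// Duplicate-free merge of two infinite increasing streams (final version from the text).
def merge1(xs: Stream[Int], ys: Stream[Int]): Stream[Int] =
  if      (xs.head <  ys.head) xs.head #:: merge1(xs.tail, ys)
  else if (xs.head == ys.head) xs.head #:: merge1(xs.tail, ys.tail)
  else                         ys.head #:: merge1(xs, ys.tail)

// The head of xs is emitted without evaluating ys, which is passed by name.
val merge: (Stream[Int], => Stream[Int]) => Stream[Int] =
  (xs, ys) => xs.head #:: merge1(xs.tail, ys)

def union(ss: Stream[Stream[Int]]): Stream[Int] =
  lazyFoldRight(Stream.empty[Int])(merge)(ss)

def primes: Stream[Int] = 2 #:: minus(Stream.from(3), composites)

def composites: Stream[Int] =
  union(for (p <- primes) yield multiples(p))

println((primes take 20).toList)
// List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71)
```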

The source code for this article is available on GitHub:
