Scala concepts: Stack Safety

A function is stack-safe when it cannot crash by exceeding the call stack's limit on nested (typically recursive) calls, however large its input.

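The sum function below works for a short range such as 1 to 5, but crashes with java.lang.StackOverflowError once the range is long enough for the recursion to exceed the thread's stack. The exact threshold depends on the stack size the JVM is running with. However, there is a way to fix it :-)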

In Scala Algorithms, we try to write the algorithms in a stack-safe way, where possible, so that when you use the algorithms, they will not crash on large inputs. However, stack-safe implementations are often more complex, and in some cases, overly complex, for the task at hand.

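// Not tail-recursive: the addition happens after the recursive call returns,
// so every step keeps a stack frame alive
def sum(from: Int, until: Int): Int =
  if (from == until) until else from + sum(from + 1, until)

def thisWillSucceed: Int = sum(1, 5)

// Fails with StackOverflowError if the recursion depth exceeds the available stack
def thisWillFail: Int = sum(1, 300)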
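To observe the failure without bringing the whole program down, the error can be caught explicitly. The following is a minimal sketch, not part of the original example: naiveSum repeats the recursive definition above under a different name, and overflowsStack is a hypothetical helper. The depth of ten million is an assumption chosen to be far beyond typical default stack sizes.

```scala
// Same shape as the recursive sum above, renamed to keep this sketch self-contained
def naiveSum(from: Int, until: Int): Int =
  if (from == until) until else from + naiveSum(from + 1, until)

// Hypothetical helper: reports whether evaluating the computation overflowed the stack.
// StackOverflowError is an Error, not an Exception, so it is matched explicitly.
def overflowsStack(computation: => Int): Boolean =
  try { computation; false }
  catch { case _: StackOverflowError => true }

val smallIsSafe = !overflowsStack(naiveSum(1, 5))
// Assumed depth: ten million nested calls, well beyond a default-sized stack
val hugeOverflows = overflowsStack(naiveSum(1, 10000000))
```

Catching the error is useful for demonstrating the problem, but it is not a fix: the computation still does not produce a result.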

A fix using Tail Recursion

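In Scala, tail recursion lets you rewrite a mutable construct, such as a while-loop, as an immutable algorithm. The compiler turns a self-recursive call in tail position into a loop, so no extra stack frames are used, and the @tailrec annotation makes compilation fail if that optimisation cannot be applied.

So, as long as you can rewrite a recursive function as an iterative one, you're good to go!

However, do not fret: even recursive functions that do not translate easily into a loop can be rewritten into stack-safe ones, using a technique called 'Defunctionalisation', which we will cover in future.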

def sum(from: Int, until: Int): Int = {
  @scala.annotation.tailrec
  def rec(current: Int, sumSoFar: Int): Int = {
    if (current > until) sumSoFar else rec(current + 1, sumSoFar + current)
  }
  rec(from, 0)
}

assert(sum(1, 5) == 15)
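As a rough illustration of the defunctionalisation idea mentioned above (not a full treatment), the pending work that the naive version leaves on the call stack can be stored as plain data instead. In this sketch, each pending "current + rest" step is recorded as an Int in a List on the heap, and then the recorded steps are applied in a second loop. The name sumDefunctionalised and the two helper names are assumptions made for this example.

```scala
import scala.annotation.tailrec

// Hypothetical name; mirrors the naive `from + sum(from + 1, until)` definition
def sumDefunctionalised(from: Int, until: Int): Int = {
  // Same precondition the naive version relies on implicitly
  require(from <= until, "from must not exceed until")

  // Phase 1: instead of leaving "current + <rest>" on the call stack,
  // record each pending addition as data in a heap-allocated List
  @tailrec
  def descend(current: Int, pending: List[Int]): (Int, List[Int]) =
    if (current == until) (until, pending)
    else descend(current + 1, current :: pending)

  // Phase 2: apply the recorded additions, most recently recorded first
  @tailrec
  def unwind(result: Int, pending: List[Int]): Int =
    pending match {
      case Nil          => result
      case head :: tail => unwind(head + result, tail)
    }

  val (base, pending) = descend(from, Nil)
  unwind(base, pending)
}

assert(sumDefunctionalised(1, 5) == 15)
```

For a simple sum this is more machinery than the accumulator version above needs, which is the trade-off noted earlier: the stack-safe form of a function can be noticeably more complex than the direct one.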

A fix using foldLeft and foldRight

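Tail recursion can be expressed as a foldLeft or foldRight, and vice versa; which is better depends on how the resulting code reads. It is also related to the State machine concept: the accumulator plays the role of the state, and each element is an input that moves it to the next state.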

def sum(from: Int, until: Int): Int = { (from to until).foldLeft(0)(_ + _) }

assert(sum(1, 5) == 15)
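To make the correspondence between folding and tail recursion concrete, here is a minimal sketch of foldLeft over a List written as a tail-recursive function. The name foldLeftRec is hypothetical; the standard library's own implementation may differ.

```scala
import scala.annotation.tailrec

// Hypothetical helper: foldLeft over a List, written as a tail-recursive function.
// The accumulator carries the result so far, exactly as in the tail-recursive sum.
@tailrec
def foldLeftRec[A, B](items: List[A], acc: B)(combine: (B, A) => B): B =
  items match {
    case Nil          => acc
    case head :: tail => foldLeftRec(tail, combine(acc, head))(combine)
  }

val viaRecursion = foldLeftRec((1 to 5).toList, 0)(_ + _)
val viaFold = (1 to 5).foldLeft(0)(_ + _)
assert(viaRecursion == viaFold)
```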

A fix using Range

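Here, we benefit from Scala's collection utilities, which allow specialisation: a Range knows its start, end and step, so operations on it can be specialised rather than implemented as a generic traversal.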

def sum(from: Int, until: Int): Int = { (from to until).sum }

assert(sum(1, 5) == 15)
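Because the parameter named until is treated as an inclusive upper bound in these examples, it may help to see how Range's own to and until differ. A small sketch using only standard Range operations (the value names are chosen for illustration):

```scala
// `to` builds an inclusive range; `until` excludes the upper bound
val inclusive = 1 to 5
val exclusive = 1 until 5

assert(inclusive.sum == 15) // 1 + 2 + 3 + 4 + 5
assert(exclusive.sum == 10) // 1 + 2 + 3 + 4

// A long range can be folded without any recursion on the call stack;
// the Long accumulator avoids overflowing Int for this total
val large = 1 to 1000000
assert(large.size == 1000000)
assert(large.foldLeft(0L)(_ + _) == 500000500000L)
```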

A fix using a direct formula

This is ideal when you can avoid iteration altogether and derive a closed-form formula that runs in O(1) time.

// Arithmetic series: (first + last) * count / 2, assuming from <= until and no Int overflow
def sum(from: Int, until: Int): Int = { (from + until) * (until - from + 1) / 2 }

assert(sum(1, 5) == 15)
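A closed-form sum computed in Int arithmetic can overflow for wide ranges, and it needs a decision about what an empty range should return. The following is a hedged sketch of one way to handle both, under the assumption that an empty range sums to zero (as it does for Range); the name sumFormula is hypothetical.

```scala
// Hypothetical variant: widen to Long so the intermediate product cannot overflow Int
def sumFormula(from: Int, until: Int): Long =
  if (from > until) 0L // assumption: an empty range sums to zero
  else (from.toLong + until.toLong) * (until.toLong - from.toLong + 1L) / 2L

// Cross-check the formula against a straightforward fold
assert(sumFormula(1, 5) == (1 to 5).foldLeft(0L)(_ + _))
assert(sumFormula(-3, 4) == (-3 to 4).foldLeft(0L)(_ + _))
```

The division is exact because either the number of terms or the sum of the two endpoints is always even.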