Continuations - why doesn't Scala perform tail call optimization?
I'm just playing with continuations. The goal is to create a function that receives another function as a parameter, plus an execution count, and returns a function that applies the parameter the given number of times.
The implementation looks pretty obvious:
    import scala.annotation.tailrec

    def n_times[T](func: T => T, count: Int): T => T = {
      @tailrec
      def n_times_cont(cnt: Int, continuation: T => T): T => T = cnt match {
        case _ if cnt < 1 => throw new IllegalArgumentException(s"count wrong $count")
        case 1 => continuation
        // wrap the continuation built so far in one more closure
        case _ => n_times_cont(cnt - 1, i => continuation(func(i)))
      }
      n_times_cont(count, func)
    }

    def inc(x: Int) = x + 1

    val res1 = n_times(inc, 1000)(1)     // works OK, returns 1001
    val res = n_times(inc, 10000000)(1)  // fails

But there is a problem: this code fails with a StackOverflowError. Why is there no tail-call optimization here?
I'm running it in Eclipse using the Scala plugin, and it returns:

    Exception in thread "main" java.lang.StackOverflowError
        at scala.runtime.BoxesRunTime.boxToInteger(Unknown Source)
        at task_mult$$anonfun$1.apply(task_mult.scala:25)
        at task_mult$$anonfun$n_times_cont$1$1.apply(task_mult.scala:18)
P.S.
The F# code, a direct translation, works without issues:
    let n_times_cnt func count =
        let rec n_times_impl count' continuation =
            match count' with
            | _ when count' < 1 -> failwith "wrong count"
            | 1 -> continuation
            // compose func in front of the continuation built so far
            | _ -> n_times_impl (count' - 1) (func >> continuation)
        n_times_impl count func

    let inc x = x + 1
    let res = (n_times_cnt inc 10000000) 1
    printfn "%O" res
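The Scala standard library has an implementation of trampolines in scala.util.control.TailCalls. Revisiting your implementation: @tailrec only turns the self-recursive call in n_times_cont into a loop, and that part works. The overflow happens later, when the resulting chain of nested closures is applied. When you build up the nested calls continuation(func(t)), those are tail calls, but they are not optimized by the compiler. So let's build up a T => TailRec[T], where the stack frames are replaced with objects on the heap. Then return a function that takes the argument and passes it to that trampolined function: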
    import util.control.TailCalls._

    def n_times_trampolined[T](func: T => T, count: Int): T => T = {
      @annotation.tailrec
      def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
        case _ if cnt < 1 => throw new IllegalArgumentException(s"count wrong $count")
        case 1 => continuation
        // tailcall defers the nested call: it becomes a heap object, not a stack frame
        case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
      }
      // lift the plain function into the trampolined world
      val lifted: T => TailRec[T] = t => done(func(t))
      // .result runs the trampoline and extracts the final value
      t => n_times_cont(count, lifted)(t).result
    }
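For comparison, here is a minimal sketch of a different technique that avoids the problem without a trampoline: instead of building a chain of nested closures, the returned function runs a self-recursive loop when it is applied. This is not the approach from the answer above. The names `NTimesLoopSketch`, `nTimesLoop` and `loop` are hypothetical and chosen only for illustration, and `require` is assumed to be an acceptable substitute for the explicit `throw` (it also raises an IllegalArgumentException).

```scala
import scala.annotation.tailrec

object NTimesLoopSketch {
  // Hypothetical helper, not from the original post: returns a function that
  // applies `func` to its argument `count` times.
  def nTimesLoop[T](func: T => T, count: Int): T => T = {
    // Same precondition as the original code: count must be at least 1.
    require(count >= 1, s"count wrong $count")

    // Direct self-recursion in tail position, so @tailrec compiles it to a loop.
    @tailrec
    def loop(remaining: Int, acc: T): T =
      if (remaining == 0) acc
      else loop(remaining - 1, func(acc))

    // No closures are nested here; the loop runs only when the result is applied.
    start => loop(count, start)
  }
}
```

With this sketch, `NTimesLoopSketch.nTimesLoop[Int](_ + 1, 10000000)(1)` evaluates to 10000001 using constant stack depth. The trade-off is that the composed function is never materialized as a continuation, so this only fits when plain repeated application is all that is needed.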