스칼라에서의 꼬리 재귀와 tailrec를 공부한다.




머리 재귀 예시이다. 일면 stack over flow가 발생할 수 있다. 


def sum1(list: List[Int]): Int = list match {
case Nil => 0
case t :: tail => t + sum1(tail)
}
println(sum1((1 to 2).toList))
//println(sum1((1 to 1000000).toList)) //Exception in thread "main" java.lang.StackOverflowError

def sum2(list: List[Int]): Int = {
if (list.isEmpty) 0
else list.head + sum2(list.tail)
}
println(sum2((1 to 2).toList))
//println(sum2((1 to 1000000).toList)) //Exception in thread "main" java.lang.StackOverflowError

결과는 3이다.





중간값을 가진 꼬리 재귀로 구현해 보자. 



def sum3(list: List[Int], acc: Int): Int = {
if (list.isEmpty) acc
else sum3(list.tail, list.head + acc)
}
println(sum3(((1 to 2).toList), 0))
println(sum3((1 to 1000000).toList, 0))

def sum4(list: List[Int], acc: Int): Int = list match {
case Nil => acc
case h :: tail => sum4(tail, h + acc)
}
println(sum4(((1 to 2).toList), 0))
println(sum4((1 to 1000000).toList, 0))

결과는 다음과 같다. 중간 값을 저장했기 때문에 stack over flow가 발생하지 않았다. 


3

1784293664

3

1784293664






중간값을 꼬리 재귀에 조금만 수정해서 앞의 entry 함수를 하나 만들어본다.

def tailrecSum(l: List[Int]): Int = {
def sum5(list: List[Int], acc: Int): Int = list match {
case Nil => acc
case x :: tail => sum5(tail, acc + x)
}
sum5(l, 0)
}

println(tailrecSum((1 to 1000000).toList))


결과는 다음과 같다. 


1784293664








스칼라에는 꼬리 재귀 최적화 기능을 가지고 있다. 


@tailrec라고 재귀 함수 앞에 붙이면 스칼라 컴파일러에 꼬리 재귀가 있으니 최적화라고 알려준다.


@tailrec를 사용하려면 다음 import문을 사용한다.

import scala.annotation.tailrec




앞에 실행했던 예시는 아래와 같이 sum5앞에 @tailrec를 붙였다.

def tailrecSum(l: List[Int]): Int = { @tailrec
def sum5(list: List[Int], acc: Int): Int = list match {
case Nil => acc
case x :: tail => sum5(tail, acc + x)
}
sum5(l, 0)
}

println(tailrecSum((1 to 1000000).toList))



@tailrec는 아무 때나 최적화되지 않고, 심지어 에러가 발생할 수 있으니. 신중히 써야 할 수 있다.



아래 코드는 Recursive call not in position 이라는 컴파일 에러가 발생한다.

@tailrec
def factorial(i: BigInt): BigInt = {
if (i == 1) i
else i * factorial(i - 1)
}

for (i <- 1 to 10)
println(s"$i:\t${factorial(i)}")


재귀 함수가 public이면, 상속받아서 쓸 수 있기 때문에 쓰지 못하도록 에러를 발생시킨다.


class Printer(msg: String) {
@tailrec
def printMessageNTimes(n: Int): Unit = {
if(n > 0){
println(msg)
printMessageNTimes(n - 1)
}
}
}

new Printer("m").printMessageNTimes(10000)

could not optimize @tailrec annotated method printMessageNTimes: it is neither private nor final so can be overridden



final 메소드로 수정하니. 정상적으로 동작한다.

class Printer(msg: String) {
@tailrec
final def printMessageNTimes(n: Int): Unit = {
if(n > 0){
println(msg)
printMessageNTimes(n - 1)
}
}
}

new Printer("m").printMessageNTimes(10000)




트램폴린(trampoline)은 여러 함수가 다른 함수를 호출하여 이루어지는 재귀를 말한다. 


X를 호출하면 A를 호출했다가 A의 내부에서 B를 호출했고 B의 내부에서 B를 호출하면서. 계속 왔다 갔다하는 형태의 재귀를 말한다. 



스칼라에는 TailCall(https://www.scala-lang.org/api/current/scala/util/control/TailCalls$.html) 이라를 객체가 있으니, 이를 참조한다.


import scala.util.control.TailCalls._

def isEven(xs: List[Int]): TailRec[Boolean] =
  if (xs.isEmpty) done(true) else tailcall(isOdd(xs.tail))

def isOdd(xs: List[Int]): TailRec[Boolean] =
 if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))

isEven((1 to 100000).toList).result

def fib(n: Int): TailRec[Int] =
  if (n < 2) done(n) else for {
    x <- tailcall(fib(n - 1))
    y <- tailcall(fib(n - 2))
  } yield (x + y)

fib(40).result


'scala' 카테고리의 다른 글

[scala] try-catch/Try-match/Either/Validation  (0) 2016.12.06
[scala] 부분 적용 함수 / 커링 / 부분 함수  (0) 2016.12.05
[scala] Future말고 Promise  (0) 2016.11.27
[scala] Future 2  (0) 2016.11.23
[scala] Future 1  (0) 2016.11.22
Posted by '김용환'
,