코세라 Scala 강의 중 MergeSort 예제에서 조금 이해하기 쉽게 분리한 예제이다.
https://www.coursera.org/learn/progfun1/lecture/0uFfe/lecture-5-2-pairs-and-tuples
object MergeSort {
  /** Merges two lists, each already sorted in ascending order, into one ascending list.
    *
    * The merge is stable: when the two heads are equal, the element from `left` is emitted
    * first. It is tail-recursive (the result is accumulated in reverse and flipped once at the
    * end), so long inputs do not overflow the stack.
    *
    * @param left  a list sorted in ascending order
    * @param right a list sorted in ascending order
    * @return all elements of `left` and `right`, in ascending order
    */
  def merge(left: List[Int], right: List[Int]): List[Int] = {
    @scala.annotation.tailrec
    def loop(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l, r) match {
      case (_, Nil) => acc.reverse ::: l
      case (Nil, _) => acc.reverse ::: r
      case (leftHead :: leftTail, rightHead :: rightTail) =>
        // `<=` lets the left element win ties, which keeps equal elements in input order.
        if (leftHead <= rightHead) loop(leftTail, r, leftHead :: acc)
        else loop(l, rightTail, rightHead :: acc)
    }
    loop(left, right, Nil)
  }

  /** Sorts `list` in ascending order using a stable top-down merge sort.
    *
    * @param list the list to sort; may be empty
    * @return a new list with the same elements in ascending order
    */
  def mergeSort(list: List[Int]): List[Int] = {
    val n = list.length / 2
    // Lists of length 0 or 1 are already sorted.
    if (n == 0) list
    else {
      val (left, right) = list splitAt n
      merge(mergeSort(left), mergeSort(right))
    }
  }
}
결과
scala> MergeSort.mergeSort(List())
res17: List[Int] = List()
scala> MergeSort.mergeSort(List(100,50,120,19))
res16: List[Int] = List(19, 50, 100, 120)
이번에는 Int 대신 제네릭 타입 T를 사용하고, 비교 함수 파라미터 (order: (T, T) => scala.Boolean)를 추가해서 어떤 타입의 리스트든 원하는 기준으로 비교·정렬할 수 있게 한다. order(x, y)는 x가 y보다 앞에 와야 할 때만 true를 반환하는 "엄격한" 비교여야 한다.
object MergeSort {
  /** Sorts `list` with a stable top-down merge sort, using `order` as the comparison.
    *
    * @param list  the list to sort; may be empty
    * @param order a strict "comes before" test: `order(x, y)` must be true exactly when `x`
    *              should precede `y` (e.g. `_ < _` for ascending, `_ > _` for descending)
    * @tparam T    the element type
    * @return a new list with the elements of `list` arranged by `order`; elements that
    *         compare as equal keep their original relative order
    */
  def mergeSort[T](list: List[T])(order: (T, T) => scala.Boolean): List[T] = {
    val n = list.length / 2
    // Lists of length 0 or 1 are already sorted.
    if (n == 0) list
    else {
      // Tail-recursive merge of two sorted lists. The result is built reversed in `acc`
      // so that long inputs do not overflow the stack.
      @scala.annotation.tailrec
      def merge(left: List[T], right: List[T], acc: List[T]): List[T] = (left, right) match {
        case (l, Nil) => acc.reverse ::: l
        case (Nil, r) => acc.reverse ::: r
        case (leftHead :: leftTail, rightHead :: rightTail) =>
          // Take from `right` only when it strictly precedes the left head. On ties the
          // left element (earlier in the input) wins, which keeps the sort stable.
          if (order(rightHead, leftHead)) merge(left, rightTail, rightHead :: acc)
          else merge(leftTail, right, leftHead :: acc)
      }
      val (left, right) = list splitAt n
      merge(mergeSort(left)(order), mergeSort(right)(order), Nil)
    }
  }
}
object Main extends App {
  // Shared sample inputs, each sorted ascending and then descending.
  private val numbers = List(100, 50, 120, 19)
  private val names = List("samuel", "jack", "juno", "ben")

  // T is inferred from the first argument list, so the comparators need no type annotations.
  println(MergeSort.mergeSort(numbers)(_ < _))
  println(MergeSort.mergeSort(numbers)(_ > _))
  println(MergeSort.mergeSort(names)((a, b) => a.compareTo(b) < 0))
  println(MergeSort.mergeSort(names)((a, b) => a.compareTo(b) > 0))
}
결과는 다음과 같다.
List(19, 50, 100, 120)
List(120, 100, 50, 19)
List(ben, jack, juno, samuel)
List(samuel, juno, jack, ben)
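스칼라는 타입 추론을 지원하므로 람다의 파라미터 타입을 생략할 수도 있다. 첫 번째 파라미터 목록(list)에서 T가 먼저 Int로 추론되기 때문에, 두 번째 파라미터 목록에 넘기는 (x, y)도 자동으로 Int 타입이 된다.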
// The lambda's parameter types are inferred: T = Int is fixed by the first argument list.
MergeSort.mergeSort(List(100, 50, 120, 19))((x, y) => x > y)
이번에는 비교 함수를 직접 넘기는 대신, 표준 라이브러리의 scala.math.Ordering을 임포트해서 사용해 본다.
import math.Ordering
object MergeSort {
  /** Sorts `list` with a stable top-down merge sort according to `order`.
    *
    * @param list  the list to sort; may be empty
    * @param order the ordering to sort by (e.g. `Ordering.Int`, `Ordering.Int.reverse`)
    * @tparam T    the element type
    * @return a new list with the elements of `list` in `order`; elements that compare as
    *         equal keep their original relative order
    */
  def mergeSort[T](list: List[T])(order: Ordering[T]): List[T] = {
    val n = list.length / 2
    // Lists of length 0 or 1 are already sorted.
    if (n == 0) list
    else {
      // Tail-recursive merge of two sorted lists. The result is built reversed in `acc`
      // so that long inputs do not overflow the stack.
      @scala.annotation.tailrec
      def merge(left: List[T], right: List[T], acc: List[T]): List[T] = (left, right) match {
        case (l, Nil) => acc.reverse ::: l
        case (Nil, r) => acc.reverse ::: r
        case (leftHead :: leftTail, rightHead :: rightTail) =>
          // `lteq` lets the left element win ties, which keeps the sort stable.
          if (order.lteq(leftHead, rightHead)) merge(leftTail, right, leftHead :: acc)
          else merge(left, rightTail, rightHead :: acc)
      }
      val (left, right) = list splitAt n
      merge(mergeSort(left)(order), mergeSort(right)(order), Nil)
    }
  }
}
object Main extends App {
  // Ascending, then descending (`.reverse` flips an Ordering), to reproduce the earlier output.
  println(MergeSort.mergeSort(List(100,50,120,19))(Ordering.Int))
  println(MergeSort.mergeSort(List(100,50,120,19))(Ordering.Int.reverse))
  println(MergeSort.mergeSort(List("samuel","jack","juno","ben"))(Ordering.String))
  println(MergeSort.mergeSort(List("samuel","jack","juno","ben"))(Ordering.String.reverse))
}
결과는 이전과 동일하다.
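여기에 implicit을 추가해 본다. order 파라미터를 implicit으로 선언하면, 컴파일러가 스코프에서 Ordering[Int]이나 Ordering[String] 같은 인스턴스를 찾아 자동으로 넘겨 준다. 그래서 호출부와 재귀 호출에서 order 인자를 생략할 수 있다. 물론 내림차순처럼 다른 Ordering이 필요할 때는 여전히 명시적으로 넘길 수 있다.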
import math.Ordering
object MergeSort {
  /** Sorts `list` with a stable top-down merge sort according to an implicit `Ordering`.
    *
    * @param list  the list to sort; may be empty
    * @param order the ordering to sort by; resolved implicitly (e.g. `Ordering.Int`) unless
    *              passed explicitly (e.g. `Ordering.Int.reverse` for descending order)
    * @tparam T    the element type
    * @return a new list with the elements of `list` in `order`; elements that compare as
    *         equal keep their original relative order
    */
  def mergeSort[T](list: List[T])(implicit order: Ordering[T]): List[T] = {
    val n = list.length / 2
    // Lists of length 0 or 1 are already sorted.
    if (n == 0) list
    else {
      // Tail-recursive merge of two sorted lists. The result is built reversed in `acc`
      // so that long inputs do not overflow the stack.
      @scala.annotation.tailrec
      def merge(left: List[T], right: List[T], acc: List[T]): List[T] = (left, right) match {
        case (l, Nil) => acc.reverse ::: l
        case (Nil, r) => acc.reverse ::: r
        case (leftHead :: leftTail, rightHead :: rightTail) =>
          // `lteq` lets the left element win ties, which keeps the sort stable.
          if (order.lteq(leftHead, rightHead)) merge(leftTail, right, leftHead :: acc)
          else merge(left, rightTail, rightHead :: acc)
      }
      val (left, right) = list splitAt n
      // `order` is in implicit scope here, so the recursive calls pick it up automatically.
      merge(mergeSort(left), mergeSort(right), Nil)
    }
  }
}
object Main extends App {
  // Ascending calls use the implicitly resolved Ordering.
  // Descending calls pass a reversed Ordering explicitly.
  println(MergeSort.mergeSort(List(100,50,120,19)))
  println(MergeSort.mergeSort(List(100,50,120,19))(Ordering.Int.reverse))
  println(MergeSort.mergeSort(List("samuel","jack","juno","ben")))
  println(MergeSort.mergeSort(List("samuel","jack","juno","ben"))(Ordering.String.reverse))
}
역시 결과가 동일하다.