spark-shell을 이용해  "랜덤 포레스트를 이용한 MNIST 데이터셋 분류" 예를 공부한다. 



이 섹션에서는 랜덤 포레스트를 사용한 분류 예를 소개할 것이다. 코드를 단계별로 분석 해결책을 쉽게 이해할 수 있다.


1단계. LIVSVM 포맷으로 MNIST 데이터셋을 로드하고 파싱한다.



import org.apache.spark.mllib.util.MLUtils

// LIBSVM 포맷의 트레이닝 데이터를 로드한다.

val data = MLUtils.loadLibSVMFile(spark.sparkContext, "data/mnist.bz2")




2단계. 트레이닝과 테스트 셋을 준비한다.

데이터를 트레이닝 셋(75%)과 테스트 셋(25%)으로 나누고 재현하기 위해 다음처럼 시드를 설정한다.



val splits = data.randomSplit(Array(0.75, 0.25), seed = 12345L)

val training = splits(0).cache()

val test = splits(1)




모델을 구축하기 위해 트레이닝 알고리즘을 실행한다.


빈 categoricalFeaturesInfo를 사용해 랜덤 포레스트 모델을 트레이닝시킨다. 모든 피쳐가 데이터셋에서 연속적이기 때문에 관련 작업이 필요하다.




import org.apache.spark.mllib.tree.RandomForest



val numClasses = 10 //MNIST 데이터 셋의 클래스의 개수

val categoricalFeaturesInfo = Map[Int, Int]()

val numTrees = 50 // 실제 상황에서는 더 큰 수를 사용한다. 값이 더 클수록 좋다.

val featureSubsetStrategy = "auto" // 알고리즘을 선택한다.

val impurity = "gini" // 이전에 설명한 랜덤 포레스트를 설명한 노트를 살펴보라.

val maxDepth = 30 // 실제 상황에서는 값이 더 클수록 좋다.

val maxBins = 32 // 실제 상황에서는 값이 더 클수록 좋다.

val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)





랜덤 포레스트 모델을 트레이닝하는 것은 매우 자원이 소비되는 작업이다. 따라서 더 많은 메모리가 필요하므로 OOM이 발생되지 않도록 주의해야 한다. 



이전에 언급한 성능 메트릭을 사용해 다음처럼 모델을 평가할 수 있도록 테스트 셋의 원 점수를 계산한다.


val scoreAndLabels = test.map { point =>

 val score = model.predict(point.features)

 (score, point.label)

}




평가를 위해 다중 클래스에 대한 메트릭을 초기화한다.


import org.apache.spark.mllib.evaluation.MulticlassMetrics

val metrics = new MulticlassMetrics(scoreAndLabels)



혼동 행렬을 생성한다.



println(metrics.confusionMatrix)


1498.0  0.0     3.0     2.0     0.0     2.0     4.0     0.0     12.0    0.0

0.0     1736.0  8.0     1.0     2.0     1.0     2.0     4.0     0.0     2.0

7.0     0.0     1424.0  2.0     3.0     1.0     5.0     12.0    10.0    4.0

0.0     3.0     20.0    1507.0  0.0     19.0    2.0     13.0    19.0    9.0

3.0     0.0     5.0     0.0     1416.0  0.0     2.0     4.0     4.0     29.0

10.0    2.0     1.0     21.0    4.0     1272.0  14.0    3.0     13.0    1.0

6.0     2.0     0.0     0.0     1.0     9.0     1456.0  0.0     4.0     0.0

2.0     1.0     6.0     0.0     9.0     1.0     0.0     1578.0  8.0     18.0

2.0     6.0     7.0     9.0     5.0     10.0    5.0     2.0     1398.0  10.0

7.0     3.0     0.0     22.0    16.0    4.0     1.0     15.0    13.0    1404.0





이전 코드는 분류를 위해 다음과 같은 혼동 행렬을 출력한다.




이제 모델의 성능을 판단하기 위해 전체 통계를 계산하자.


정확도, 정밀도, 회수율, 참 긍정 비율, 거짓 긍정 비율, F1 점수와 같은 성능 메트릭을 포함하는 다음 출력을 생성한다.



val accuracy = metrics.accuracy

println("Summary Statistics")

println(s"Accuracy = $accuracy")

// 레이블 당 정확도

val labels = metrics.labels

labels.foreach { l =>

 println(s"Precision($l) = " + metrics.precision(l))

}

// 레이블 당 회수율

labels.foreach { l =>

 println(s"Recall($l) = " + metrics.recall(l))

}

// 레이블 당 거짓 긍정 비율

labels.foreach { l =>

 println(s"FPR($l) = " + metrics.falsePositiveRate(l))

}

// 레이블 당 F-측정 값

labels.foreach { l =>

 println(s"F1-Score($l) = " + metrics.fMeasure(l))





실제 실행 결과는 다음과 같다.


scala> val accuracy = metrics.accuracy

accuracy: Double = 0.967591067782096


scala> println("Summary Statistics")

Summary Statistics


scala> println(s"Accuracy = $accuracy")

Accuracy = 0.967591067782096


scala> // 레이블 당 정확도


scala> val labels = metrics.labels

labels: Array[Double] = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)


scala> labels.foreach { l =>

     |  println(s"Precision($l) = " + metrics.precision(l))

     | }

Precision(0.0) = 0.9758957654723127

Precision(1.0) = 0.9903023388476897

Precision(2.0) = 0.966078697421981

Precision(3.0) = 0.9635549872122762

Precision(4.0) = 0.9725274725274725

Precision(5.0) = 0.9643669446550417

Precision(6.0) = 0.9765258215962441

Precision(7.0) = 0.967504598405886

Precision(8.0) = 0.9439567859554355

Precision(9.0) = 0.950575490859851


scala> // 레이블 당 회수율


scala> labels.foreach { l =>

     |  println(s"Recall($l) = " + metrics.recall(l))

     | }

Recall(0.0) = 0.9848783694937541

Recall(1.0) = 0.9886104783599089

Recall(2.0) = 0.9700272479564033

Recall(3.0) = 0.946608040201005

Recall(4.0) = 0.9678742310321258

Recall(5.0) = 0.9485458612975392

Recall(6.0) = 0.9851150202976996

Recall(7.0) = 0.9722735674676525

Recall(8.0) = 0.9614855570839065

Recall(9.0) = 0.9454545454545454


scala> // 레이블 당 거짓 긍정 비율


scala> labels.foreach { l =>

     |  println(s"FPR($l) = " + metrics.falsePositiveRate(l))

     | }

FPR(0.0) = 0.0027086383601756954

FPR(1.0) = 0.001266294227188082

FPR(2.0) = 0.003646175162254795

FPR(3.0) = 0.004194569136801825

FPR(4.0) = 0.0029158769499927103

FPR(5.0) = 0.0033959537572254336

FPR(6.0) = 0.0025541852149164415

FPR(7.0) = 0.003909131140286178

FPR(8.0) = 0.006046477744590952

FPR(9.0) = 0.005330023364485981


scala> // 레이블 당 F-측정 값


scala> labels.foreach { l =>

     |  println(s"F1-Score($l) = " + metrics.fMeasure(l))

     | }

F1-Score(0.0) = 0.9803664921465969

F1-Score(1.0) = 0.9894556853804503

F1-Score(2.0) = 0.9680489462950375

F1-Score(3.0) = 0.9550063371356146

F1-Score(4.0) = 0.9701952723535457

F1-Score(5.0) = 0.956390977443609

F1-Score(6.0) = 0.9808016167059616

F1-Score(7.0) = 0.9698832206515058

F1-Score(8.0) = 0.952640545144804

F1-Score(9.0) = 0.9480081026333558


scala>






이제 전체 통계를 다음처럼 계산하자.


println(s"Weighted precision: ${metrics.weightedPrecision}")

println(s"Weighted recall: ${metrics.weightedRecall}")

println(s"Weighted F1 score: ${metrics.weightedFMeasure}")

println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")

val testErr = scoreAndLabels.filter(r => r._1 != r._2).count.toDouble / test.count()

println("Accuracy = " + (1-testErr) * 100 + " %")




이전 코드는 가중치 정밀도, 회수율, F1 점수, 거짓 긍정 비율을 포함하는 다음 출력을 출력한다.


Overall statistics

----------------------------

Weighted precision: 0.9676041167963592

Weighted recall: 0.9675910677820959

Weighted F1 score: 0.9675700010426889

Weighted false positive rate: 0.03240893221790396

Accuracy = 96.7591067782096 %




전체 통계에 따르면 모형의 정확도는 96%이상 로지스틱 회귀 분석보다 우수하다. 


그러나 모델을 잘 튜닝하면 더욱 개선될 수 있다.


Posted by '김용환'
,