spark-shell을 이용한 "로지스틱 회귀 분석을 이용한 멀티 클래스 분류" 예이다. 


다중 클래스 분류 문제를 트레이닝하고 예측하기 위해 이진 로지스틱 회귀를 다항 로지스틱 회귀로 일반화할 수 있다. 



예를 들어 K개의 가능한 결과에 대해 결과 중 하나를 피벗으로 선택하고 다른 K-1개의 결과는 피벗 결과에 대해 개별적으로 회귀될 수 있다. 


spark.mllib에서 첫 번째 클래스 0은 피벗(pivot) 클래스로 선택된다.


다중 클래스 분류 문제의 경우 알고리즘은 첫 번째 클래스에 대해 회귀된 k-1 이진 로지스틱 회귀 모델을 포함하는 다항 로지스틱 회귀 모델을 출력한다.


 새로운 데이터 포인트가 주어지면 k-1 모델은 실행되고 가장 큰 확률을 가진 클래스가 예측 클래스로 선택된다. 


이 섹션에서는 더 빠른 수렴을 위해 L-BFGS를 사용하는 로지스틱 회귀 분석을 사용해 분류하는 예를 보여준다.





LIVSVM 포맷으로 https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html에서 

MNIST 데이터셋(https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2)을 로드하고 파싱한다.




import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(spark.sparkContext, "data/mnist.bz2")



다음처럼 데이터를 트레이닝 셋(75%)과 테스트 셋(25%)으로 나눈다.



val splits = data.randomSplit(Array(0.75, 0.25), seed = 12345L)

val training = splits(0).cache()

val test = splits(1)




트레이닝 알고리즘을 실행해 다중 클래스(이 데이터 셋의 경우는 10개이다)를 설정하여 모델을 구축한다. 



분류 정확도를 높이려면 다음처럼 Boolean true 값을 사용해 데이터셋에 인터셉트를 추가(setIntercept)한 후 유효성을 검사(setValidateData)해야 한다.


import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

val model = new LogisticRegressionWithLBFGS()

          .setNumClasses(10)

          .setIntercept(true)

          .setValidateData(true)

          .run(training)





알고리즘이 setIntercept를 사용하여 인터셉트를 추가해야하는 경우 인터셉트를 true로 설정한다. 



모델 구축 전에 알고리즘에 트레이닝 셋으로 유효성을 검사하려면 setValidateData 함수를 사용하여 값을 true로 설정해야한다.




다음처럼 기본 임계값을 지워 기본 설정으로 트레이닝하지 않는다.


model.clearThreshold()





테스트 셋의 원본 점수를 계산해서 이전에 언급한 성능 메트릭를 사용해 다음처럼 모델을 평가할 수 있다.



val scoreAndLabels = test.map { point =>

 val score = model.predict(point.features)

 (score, point.label)

}



평가를 할 수 있도록 여러 개의 메트릭을 초기화한다.



import org.apache.spark.mllib.evaluation.MulticlassMetrics

val metrics = new MulticlassMetrics(scoreAndLabels)




혼동 행렬를 구축한다.



println(metrics.confusionMatrix)


1466.0  1.0     4.0     2.0     3.0     11.0    18.0    1.0     11.0    4.0

0.0     1709.0  11.0    3.0     2.0     6.0     1.0     5.0     15.0    4.0

10.0    17.0    1316.0  24.0    22.0    8.0     20.0    17.0    26.0    8.0

3.0     9.0     38.0    1423.0  1.0     52.0    9.0     11.0    31.0    15.0

3.0     4.0     23.0    1.0     1363.0  4.0     10.0    7.0     5.0     43.0

19.0    7.0     11.0    50.0    12.0    1170.0  23.0    6.0     32.0    11.0

6.0     2.0     15.0    3.0     10.0    19.0    1411.0  2.0     8.0     2.0

4.0     7.0     10.0    7.0     14.0    4.0     2.0     1519.0  8.0     48.0

9.0     22.0    26.0    43.0    11.0    46.0    16.0    5.0     1268.0  8.0

6.0     3.0     5.0     23.0    39.0    8.0     0.0     60.0    14.0    1327.0






혼동 행렬에서 행렬의 각 컬럼은 예측 클래스의 인스턴스를 나타내는 반면, 


각 라인은 실제 클래스의 인스턴스를 나타낸다(또는 그 반대). 


이름은 시스템이 2개의 클래스를 혼동하고 있는지 쉽게 알 수 있게 한다는 사실에서 유래한다. 


자세한 내용은 혼동 행렬(https://en.wikipedia.org/wiki/Confusion_matrix.Confusion)를 참조한다.






이제 모델의 성능을 판단하기 위해 전체 통계를 계산해보자.



val accuracy = metrics.accuracy

println("Summary Statistics")

println(s"Accuracy = $accuracy")

// 레이블 당 정확도

val labels = metrics.labels

labels.foreach { l =>

 println(s"Precision($l) = " + metrics.precision(l))

}

// 레이블 당 회수율

labels.foreach { l =>

 println(s"Recall($l) = " + metrics.recall(l))

}

// 레이블 당 거짓 긍정 비율

labels.foreach { l =>

 println(s"FPR($l) = " + metrics.falsePositiveRate(l))

}

// 레이블 당 F-측정 값

labels.foreach { l =>

 println(s"F1-Score($l) = " + metrics.fMeasure(l))

}




이전 코드 세그먼트는 정확도, 정밀도, 회수율, 참 긍정 비율, 오 탐지율 및 F1 점수와 같은 성능 메트릭을 포함하는 다음 출력을 생성한다.


Summary Statistics

----------------------

scala> // 레이블 당 정확도


scala> val labels = metrics.labels

labels: Array[Double] = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)


scala> labels.foreach { l =>

     |  println(s"Precision($l) = " + metrics.precision(l))

     | }

Precision(0.0) = 0.9606815203145478

Precision(1.0) = 0.9595732734418866

Precision(2.0) = 0.9019876627827279

Precision(3.0) = 0.9012032932235592

Precision(4.0) = 0.922816519972918

Precision(5.0) = 0.8810240963855421

Precision(6.0) = 0.9344370860927153

Precision(7.0) = 0.9301898346601347

Precision(8.0) = 0.8942172073342737

Precision(9.0) = 0.9027210884353741


scala> // 레이블 당 회수율


scala> labels.foreach { l =>

     |  println(s"Recall($l) = " + metrics.recall(l))

     | }

Recall(0.0) = 0.9638395792241946

Recall(1.0) = 0.9732346241457859

Recall(2.0) = 0.896457765667575

Recall(3.0) = 0.8938442211055276

Recall(4.0) = 0.9316473000683527

Recall(5.0) = 0.87248322147651

Recall(6.0) = 0.9546684709066305

Recall(7.0) = 0.9359211337030191

Recall(8.0) = 0.8720770288858322

Recall(9.0) = 0.8936026936026936


scala> // 레이블 당 거짓 긍정 비율


scala> labels.foreach { l =>

     |  println(s"FPR($l) = " + metrics.falsePositiveRate(l))

     | }

FPR(0.0) = 0.004392386530014641

FPR(1.0) = 0.005363128491620112

FPR(2.0) = 0.010428060964048714

FPR(3.0) = 0.011479873427036574

FPR(4.0) = 0.008310249307479225

FPR(5.0) = 0.011416184971098265

FPR(6.0) = 0.0072246953221922205

FPR(7.0) = 0.00840831981118159

FPR(8.0) = 0.010927369417935456

FPR(9.0) = 0.010441004672897197


scala> // 레이블 당 F-측정 값


scala> labels.foreach { l =>

     |  println(s"F1-Score($l) = " + metrics.fMeasure(l))

     | }

F1-Score(0.0) = 0.9622579586478502

F1-Score(1.0) = 0.966355668645745

F1-Score(2.0) = 0.8992142125042706

F1-Score(3.0) = 0.8975086723431095

F1-Score(4.0) = 0.9272108843537414

F1-Score(5.0) = 0.876732858748595

F1-Score(6.0) = 0.9444444444444444

F1-Score(7.0) = 0.933046683046683

F1-Score(8.0) = 0.883008356545961

F1-Score(9.0) = 0.8981387478849409




이제 전체 통계, 즉 요약 통계를 계산하자.


println(s"Weighted precision: ${metrics.weightedPrecision}")

println(s"Weighted recall: ${metrics.weightedRecall}")

println(s"Weighted F1 score: ${metrics.weightedFMeasure}")

println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") 




이전 코드는 가중치 정밀도, 회수율, F1 점수, 거짓 긍정 비율을 다음처럼 출력한다.


Weighted precision: 0.920104303076327

Weighted recall: 0.9203609775377117

Weighted F1 score: 0.9201934861645358

Weighted false positive rate: 0.008752250453215607




전체 통계에 따르면 모델의 정확도는 92%이상이다. 


그러나 랜덤 포레스트(RF)와 같은 더 좋은 알고리즘을 사용하면 성능이 향상된다.



Posted by '김용환'
,