spark-shell을 이용한 "로지스틱 회귀 분석을 이용한 멀티 클래스 분류" 예이다.
다중 클래스 분류 문제를 트레이닝하고 예측하기 위해 이진 로지스틱 회귀를 다항 로지스틱 회귀로 일반화할 수 있다.
예를 들어 K개의 가능한 결과에 대해 결과 중 하나를 피벗으로 선택하고 다른 K-1개의 결과는 피벗 결과에 대해 개별적으로 회귀될 수 있다.
spark.mllib에서 첫 번째 클래스 0은 피벗(pivot) 클래스로 선택된다.
다중 클래스 분류 문제의 경우 알고리즘은 첫 번째 클래스에 대해 회귀된 k-1 이진 로지스틱 회귀 모델을 포함하는 다항 로지스틱 회귀 모델을 출력한다.
새로운 데이터 포인트가 주어지면 k-1 모델은 실행되고 가장 큰 확률을 가진 클래스가 예측 클래스로 선택된다.
이 섹션에서는 더 빠른 수렴을 위해 L-BFGS를 사용하는 로지스틱 회귀 분석을 사용해 분류하는 예를 보여준다.
LIVSVM 포맷으로 https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html에서
MNIST 데이터셋(https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2)을 로드하고 파싱한다.
import org.apache.spark.mllib.util.MLUtils
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "data/mnist.bz2")
다음처럼 데이터를 트레이닝 셋(75%)과 테스트 셋(25%)으로 나눈다.
val splits = data.randomSplit(Array(0.75, 0.25), seed = 12345L)
val training = splits(0).cache()
val test = splits(1)
트레이닝 알고리즘을 실행해 다중 클래스(이 데이터 셋의 경우는 10개이다)를 설정하여 모델을 구축한다.
분류 정확도를 높이려면 다음처럼 Boolean true 값을 사용해 데이터셋에 인터셉트를 추가(setIntercept)한 후 유효성을 검사(setValidateData)해야 한다.
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.setIntercept(true)
.setValidateData(true)
.run(training)
알고리즘이 setIntercept를 사용하여 인터셉트를 추가해야하는 경우 인터셉트를 true로 설정한다.
모델 구축 전에 알고리즘에 트레이닝 셋으로 유효성을 검사하려면 setValidateData 함수를 사용하여 값을 true로 설정해야한다.
다음처럼 기본 임계값을 지워 기본 설정으로 트레이닝하지 않는다.
model.clearThreshold()
테스트 셋의 원본 점수를 계산해서 이전에 언급한 성능 메트릭를 사용해 다음처럼 모델을 평가할 수 있다.
val scoreAndLabels = test.map { point =>
val score = model.predict(point.features)
(score, point.label)
}
평가를 할 수 있도록 여러 개의 메트릭을 초기화한다.
import org.apache.spark.mllib.evaluation.MulticlassMetrics
val metrics = new MulticlassMetrics(scoreAndLabels)
혼동 행렬를 구축한다.
println(metrics.confusionMatrix)
1466.0 1.0 4.0 2.0 3.0 11.0 18.0 1.0 11.0 4.0
0.0 1709.0 11.0 3.0 2.0 6.0 1.0 5.0 15.0 4.0
10.0 17.0 1316.0 24.0 22.0 8.0 20.0 17.0 26.0 8.0
3.0 9.0 38.0 1423.0 1.0 52.0 9.0 11.0 31.0 15.0
3.0 4.0 23.0 1.0 1363.0 4.0 10.0 7.0 5.0 43.0
19.0 7.0 11.0 50.0 12.0 1170.0 23.0 6.0 32.0 11.0
6.0 2.0 15.0 3.0 10.0 19.0 1411.0 2.0 8.0 2.0
4.0 7.0 10.0 7.0 14.0 4.0 2.0 1519.0 8.0 48.0
9.0 22.0 26.0 43.0 11.0 46.0 16.0 5.0 1268.0 8.0
6.0 3.0 5.0 23.0 39.0 8.0 0.0 60.0 14.0 1327.0
혼동 행렬에서 행렬의 각 컬럼은 예측 클래스의 인스턴스를 나타내는 반면,
각 라인은 실제 클래스의 인스턴스를 나타낸다(또는 그 반대).
이름은 시스템이 2개의 클래스를 혼동하고 있는지 쉽게 알 수 있게 한다는 사실에서 유래한다.
자세한 내용은 혼동 행렬(https://en.wikipedia.org/wiki/Confusion_matrix.Confusion)를 참조한다.
이제 모델의 성능을 판단하기 위해 전체 통계를 계산해보자.
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
// 레이블 당 정확도
val labels = metrics.labels
labels.foreach { l =>
println(s"Precision($l) = " + metrics.precision(l))
}
// 레이블 당 회수율
labels.foreach { l =>
println(s"Recall($l) = " + metrics.recall(l))
}
// 레이블 당 거짓 긍정 비율
labels.foreach { l =>
println(s"FPR($l) = " + metrics.falsePositiveRate(l))
}
// 레이블 당 F-측정 값
labels.foreach { l =>
println(s"F1-Score($l) = " + metrics.fMeasure(l))
}
이전 코드 세그먼트는 정확도, 정밀도, 회수율, 참 긍정 비율, 오 탐지율 및 F1 점수와 같은 성능 메트릭을 포함하는 다음 출력을 생성한다.
Summary Statistics
----------------------
scala> // 레이블 당 정확도
scala> val labels = metrics.labels
labels: Array[Double] = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
scala> labels.foreach { l =>
| println(s"Precision($l) = " + metrics.precision(l))
| }
Precision(0.0) = 0.9606815203145478
Precision(1.0) = 0.9595732734418866
Precision(2.0) = 0.9019876627827279
Precision(3.0) = 0.9012032932235592
Precision(4.0) = 0.922816519972918
Precision(5.0) = 0.8810240963855421
Precision(6.0) = 0.9344370860927153
Precision(7.0) = 0.9301898346601347
Precision(8.0) = 0.8942172073342737
Precision(9.0) = 0.9027210884353741
scala> // 레이블 당 회수율
scala> labels.foreach { l =>
| println(s"Recall($l) = " + metrics.recall(l))
| }
Recall(0.0) = 0.9638395792241946
Recall(1.0) = 0.9732346241457859
Recall(2.0) = 0.896457765667575
Recall(3.0) = 0.8938442211055276
Recall(4.0) = 0.9316473000683527
Recall(5.0) = 0.87248322147651
Recall(6.0) = 0.9546684709066305
Recall(7.0) = 0.9359211337030191
Recall(8.0) = 0.8720770288858322
Recall(9.0) = 0.8936026936026936
scala> // 레이블 당 거짓 긍정 비율
scala> labels.foreach { l =>
| println(s"FPR($l) = " + metrics.falsePositiveRate(l))
| }
FPR(0.0) = 0.004392386530014641
FPR(1.0) = 0.005363128491620112
FPR(2.0) = 0.010428060964048714
FPR(3.0) = 0.011479873427036574
FPR(4.0) = 0.008310249307479225
FPR(5.0) = 0.011416184971098265
FPR(6.0) = 0.0072246953221922205
FPR(7.0) = 0.00840831981118159
FPR(8.0) = 0.010927369417935456
FPR(9.0) = 0.010441004672897197
scala> // 레이블 당 F-측정 값
scala> labels.foreach { l =>
| println(s"F1-Score($l) = " + metrics.fMeasure(l))
| }
F1-Score(0.0) = 0.9622579586478502
F1-Score(1.0) = 0.966355668645745
F1-Score(2.0) = 0.8992142125042706
F1-Score(3.0) = 0.8975086723431095
F1-Score(4.0) = 0.9272108843537414
F1-Score(5.0) = 0.876732858748595
F1-Score(6.0) = 0.9444444444444444
F1-Score(7.0) = 0.933046683046683
F1-Score(8.0) = 0.883008356545961
F1-Score(9.0) = 0.8981387478849409
이제 전체 통계, 즉 요약 통계를 계산하자.
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
이전 코드는 가중치 정밀도, 회수율, F1 점수, 거짓 긍정 비율을 다음처럼 출력한다.
Weighted precision: 0.920104303076327
Weighted recall: 0.9203609775377117
Weighted F1 score: 0.9201934861645358
Weighted false positive rate: 0.008752250453215607
전체 통계에 따르면 모델의 정확도는 92%이상이다.
그러나 랜덤 포레스트(RF)와 같은 더 좋은 알고리즘을 사용하면 성능이 향상된다.
'scala' 카테고리의 다른 글
[spark] 기본 파티션 개수 (0) | 2018.10.12 |
---|---|
[spark] "랜덤 포레스트를 이용한 MNIST 데이터셋 분류" 예 (0) | 2018.06.01 |
[spark] spark-shell 메모리/cpu 설정 (0) | 2018.05.31 |
[spark] 스파크 머신 러닝(ML) api을 사용하여 파이프 라인 개발하기 - 유방암 가능성 예측 (0) | 2018.05.31 |
[spark] 머신러닝 - SGD(선형 회귀 기반 알고리즘) 적용 예 (0) | 2018.05.30 |