[ML] 하나의 class만 학습시켜서 불균형 데이터 예측하기

2020. 11. 12. 17:44ML&DL

Binary classification을 할 때 class의 데이터가 매우 불균형하거나 class에 한 개의 데이터만 있고 나머지 데이터가 없는 경우들이 있습니다. 이런 경우는 하나의 class만 학습시켜서 분류를 할 수 있는데 그 중 사용되는 방법 중 하나가 OCSVM입니다.

 

OCSVM

 

OCSVM은 One-class SVM의 줄임말로 SVM(Support vector machine)과 달리 비지도 학습(unsupervised learning)입니다. 주어진 데이터를 잘 설명할 수 있는 최적의 support vector를 구하고 이 영역 밖의 데이터들은 outlier로 간주하는 방식으로 이상치 탐지, 이미지 검색, 문서/ 텍스트 분류 등에 사용되었습니다. 

 

 

출처: https://www.researchgate.net/figure/One-class-SVM-boundary-and-outlier-detection_fig5_281455041

위 그림과 같이 알고리즘은 초록색 원의 샘플 데이터에 대해 학습한 후 최적의 support vecotr를 찾아내고 그 외 벗어나는 구간의 데이터를 outlier라고 간주하는 것입니다.

 

 

 

코드

R 함수로 구현해보겠습니다.

 

데이터는 ROSE 패키지 내장데이터인 hacide입니다. hacide데이터는 불균형 이항 분류를 위한 모의 데이터 세트입니다. 

 

library(e1071)  # SVM 모델을 구현하기 위한 패키지
library(ROSE)   # 데이터를 사용하기 위한 패키지 
library(caret)

data(hacide)

data = rbind(hacide.train,hacide.test)
str(data)
table(data$cls)

 

데이터의 독립변수는 x1, x2 두 개, 타겟 변수 cls는 0이 98%, 1이 2% 비율로 차지하고 있습니다. 

 

# 한 쪽 데이터만 학습시킨 후 모형이 잘 예측했는지 확인하기 위해 TRUE, FALSE로 변경 
# True, False 로 변경해주지 않으면 모형 적용 후 confusionmatrix를 확인할 때 행, 열의 이름이 달라 에러 발생
data$class[data$cls=="0"] = "TRUE" 
data$class[data$cls!="0"] = "FALSE"

# TRUE와 FALSE인 데이터 나눔 
data_True<-subset(data,class=="TRUE")
data_False<-subset(data,class=="FALSE")

# train/test data split
inTrain<-createDataPartition(1:nrow(data_True),p=0.6,list=FALSE)

# train 데이터 생성 
train_x<-data_True[inTrain,2:3]
train_y<-data_True[inTrain,4]

# test 데이터 생성 
test<-rbind(data_True[-inTrain,],data_False)

test_x <-test[,2:3]
test_y <-test[,4]

 

타겟 변수가 0인 것만 학습시키기 위해 train 데이터의 타겟 변수는 모두 0에 해당하는 TRUE 가 들어가도록, test 데이터는 0, 1 모두 들어가도록 나눠줍니다. 

 

svm.model<-svm(train_x,y=NULL,
               type='one-classification',
               nu=0.10,
               scale=TRUE,
               kernel="radial")  # 방사 kernel 함수를 이용하여 적용 

svm.predtrain<-predict(svm.model,train_x)
svm.predtest<-predict(svm.model,test_x)

 

기본 svm 모델에 type = "one-classfication" 을 입력하여 한 쪽 class만 학습하는 모델을 적용합니다.  

 

confTrain<-table(Predicted=svm.predtrain,Reference=train_y)
confTest<-table(Predicted=svm.predtest,Reference=test_y )

confusionMatrix(confTest,positive='TRUE')

 

모델 적용 결과 데이터가 불균형할 때 확인하는 정확도 지표인 Balanced Accuracy가 0.93으로 높은 정확도를 보이는 것으로 확인이되고 confusion matrix 결과 cls가 1인데 0으로 예측한 경우가 25개 중 1개로 , cls가 0인데 1로 예측한 경우가 488개 중 48개로 대체로 잘 예측한 것으로 보입니다. 

 

 

참고

[1] OCSVM

[2] Balance Accuracy

728x90