MACHINE LEARNING/Machine Learning Library

ML(머신러닝) : Logistic Regression 개념 정리 (sklearn 으로 classifier 생성 및 Confusion Matrix = cm 만들기)

신강희 2024. 4. 20. 01:13
728x90

< Logistic Regression 이란?? >

- Logistic Regression은 기계 학습과 통계에서 사용되는 통계적 분류 방법

- 주로 두 개 이상의 클래스 중 하나에 속하는 경우를 예측하는 데 사용된다.

- 주로 이진 분류(binary classification) 문제에 적용되며, 예를 들어 스팸 메일 여부 판별, 질병 진단 등 다양한 분야에서 활용된다.

- 이 방법은 선형 회귀(Linear Regression)와 비슷해 보이지만, 출력 값이 0과 1 사이의 확률값으로 제한되며, S자 형태의 로지스틱 함수를 사용하고, 이를 통해 예측값을 확률로 해석할 수 있다.

- 간단히 말하면, Logistic Regression은 주어진 입력 변수를 기반으로 특정 클래스에 속할 확률을 예측하는 모델

 

분류에 사용한다. (Classification)

예) 나이대별로 이메일을 클릭해서 열지 말지를 분류해 보자.

 

이메일 클릭을 할 사람과 안할 사람으로 분류할 것이다.

빨간점이 바로 데이터이며,

액션의 0 과 1 이 바로 레이블이다.

레이블이 있다는것은, 수퍼바이저드 러닝 이라는 뜻

 

이렇게 비슷하게 생긴 함수가 이미존재한다. 이름은 sigmoid function

따라서 리니어 리그레션 식을, y 값을 시그모이드에 대입해서, 일차방정식으로 만들면 다음과 같아진다.

위와 같은 식을 가진 regression 을, Logistic Regression이라 한다.

이제 우리는, 이를 가지고 두개의 클래스로 분류할 수 있다. ( 클릭을 한다, 안한다 두개로.)

 

확률로 나타낼 수 있게 되었다.

p는 확률값을 나타낸다.

 

20대는 클릭할 확률이 0.7%, 40대는 85%, 50대는 99.4%

이 확률값은, 위에서의 시그모이드 함수를 적용한 식을 통해 나온 값임을 기억한다.

최종 예측값은, 0.5를 기준으로 두개의 부류로 나눈다. 그 값은 0 과 1 이다.

 

< 이제 실제 예제문을 통해  Logistic Regression 실행해 보자 >

 

# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

 

# 새로운 데이터를 불러와 실행

df = pd.read_csv('../data/Social_Network_Ads.csv')

df

 

# 나이와 연봉으로 분석해서, 물건을 구매할지 안할지를 분류하자

 

# 1) NaN 확인
df.isna().sum()

User ID            0
Gender             0
Age                0
EstimatedSalary    0
Purchased          0
dtype: int64

 

# 2) X,y 분류

y = df['Purchased']

y

0      0
1      0
2      0
3      0
4      0
      ..
395    1
396    1
397    1
398    0
399    1
Name: Purchased, Length: 400, dtype: int64

 

X = df.loc[ : , 'Age' : 'EstimatedSalary' ]

X

 

# 3) 인코딩이 필요하지만 이번 경우엔

# y의 경우 0,1로만 데이터가 구성되어 있어서 이미 레이블링 인코딩이 된상태로 봐도 무관
# X 데이터도 다 숫자로 되어있어서 인코딩 필요 없음
y.unique()

array([0, 1], dtype=int64)

 

# 4) 로지스틱 리그레이션은, 피쳐 스케일링을 하자. (린이어는 필요 없었음)
# X값은 두 컬럼의 숫자범위 차이가 크므로 스케일링 필요
# y는 0,1로만 구성되어 있으므로 필요 없음

 

from sklearn.preprocessing import StandardScaler, MinMaxScaler

# 둘중의 원하는것 사용

# X_scalerS = StandardScaler()

# 이번엔 정규화를 사용해 보자

 

# import한 인코딩 기능을 사용하기 위해 변수로 저장

X_scalerM = MinMaxScaler()

 

# 사용법
X = X_scalerM.fit_transform( X )

 

# 5) train, test 용으로 나누기

from sklearn.model_selection import train_test_split

 

# 트레인 테스트 범위 설정, X를 참고하여 y를 예측하고 train 75%, test는 25%를 사용해라.

train_test_split( X , y, test_size=0.25, random_state=1  )

[array([[0.04761905, 0.25185185],
        [0.66666667, 0.54074074],
        [0.30952381, 0.14074074],
        [0.69047619, 0.11111111],
        [0.4047619 , 0.25925926],
        [0.5       , 0.6       ],
        [0.35714286, 0.72592593],
        [0.73809524, 0.52592593],
        [0.64285714, 0.47407407],
        [0.61904762, 0.17777778],
        [0.54761905, 0.32592593],
        [0.57142857, 0.28148148],
        [0.4047619 , 0.42962963],
        [0.54761905, 0.42222222],
        [0.21428571, 0.6       ],
        [0.28571429, 0.74814815],
        [0.54761905, 0.27407407],
        [0.54761905, 0.27407407],
        [0.04761905, 0.4962963 ],
        [0.66666667, 0.19259259],
        [0.21428571, 0.11851852],
        [0.4047619 , 0.41481481],
        [0.73809524, 0.0962963 ],
        [0.4047619 , 0.56296296],
        [0.45238095, 0.44444444],
        [0.33333333, 0.75555556],
        [0.42857143, 0.44444444],
        [0.04761905, 0.52592593],
        [0.47619048, 0.25925926],
        [0.73809524, 0.15555556],
        [0.52380952, 0.37037037],
        [0.45238095, 0.45925926],
        [1.        , 0.22962963],
        [0.71428571, 0.91111111],
        [0.52380952, 0.41481481],
        [0.42857143, 0.35555556],
        [0.19047619, 0.48888889],
        [0.35714286, 0.11851852],
        [0.54761905, 0.26666667],
        [0.66666667, 0.43703704],
        [0.        , 0.4962963 ],
        [0.45238095, 0.13333333],
        [0.4047619 , 0.28148148],
        [0.23809524, 0.32592593],
        [0.45238095, 0.41481481],
        [0.        , 0.21481481],
        [0.47619048, 0.34074074],
        [0.4047619 , 0.37037037],
        [0.30952381, 0.37777778],
        [0.69047619, 0.07407407],
        [0.57142857, 0.65925926],
        [0.02380952, 0.02962963],
        [0.07142857, 0.42222222],
        [0.23809524, 0.12592593],
        [0.80952381, 1.        ],
        [0.83333333, 0.65925926],
        [0.16666667, 0.48148148],
        [0.19047619, 0.42222222],
        [0.19047619, 0.        ],
        [0.52380952, 0.94074074],
        [0.4047619 , 0.17037037],
        [0.92857143, 0.79259259],
        [0.54761905, 0.53333333],
        [0.14285714, 0.54814815],
        [0.45238095, 0.27407407],
        [0.4047619 , 0.68888889],
        [0.33333333, 0.62962963],
        [0.4047619 , 0.97777778],
        [0.02380952, 0.08148148],
        [0.66666667, 0.4962963 ],
        [0.69047619, 0.66666667],
        [0.64285714, 0.22222222],
        [0.47619048, 0.26666667],
        [0.45238095, 0.57777778],
        [0.26190476, 0.20740741],
        [0.54761905, 0.11111111],
        [0.52380952, 0.33333333],
        [0.21428571, 0.54074074],
        [0.21428571, 0.55555556],
        [0.5       , 0.41481481],
        [0.11904762, 0.0962963 ],
        [0.78571429, 0.97037037],
        [0.11904762, 0.35555556],
        [0.45238095, 0.48148148],
        [0.52380952, 0.23703704],
        [0.71428571, 0.1037037 ],
        [0.4047619 , 0.32592593],
        [0.19047619, 0.11111111],
        [0.5       , 0.67407407],
        [0.23809524, 0.51111111],
        [0.97619048, 0.45185185],
        [0.52380952, 0.31111111],
        [0.4047619 , 0.60740741],
        [0.19047619, 0.52592593],
        [0.54761905, 0.35555556],
        [0.26190476, 0.23703704],
        [0.19047619, 0.01481481],
        [0.95238095, 0.05925926],
        [0.28571429, 0.34814815],
        [0.16666667, 0.13333333],
        [0.97619048, 0.1037037 ],
        [0.23809524, 0.16296296],
        [0.5       , 0.45925926],
        [0.69047619, 0.68148148],
        [0.80952381, 0.91111111],
        [0.26190476, 0.20740741],
        [0.21428571, 0.9037037 ],
        [0.14285714, 0.2962963 ],
        [0.52380952, 0.42222222],
        [0.07142857, 0.00740741],
        [0.92857143, 0.08148148],
        [0.28571429, 0.88888889],
        [0.71428571, 0.77037037],
        [0.52380952, 0.31111111],
        [0.4047619 , 0.44444444],
        [0.23809524, 0.21481481],
        [0.45238095, 0.43703704],
        [0.4047619 , 0.08888889],
        [0.45238095, 0.47407407],
        [0.33333333, 0.77777778],
        [0.26190476, 0.44444444],
        [0.28571429, 0.01481481],
        [0.16666667, 0.47407407],
        [0.52380952, 0.68148148],
        [0.14285714, 0.02962963],
        [0.54761905, 0.42222222],
        [0.47619048, 0.34074074],
        [0.42857143, 0.95555556],
        [0.57142857, 0.36296296],
        [0.71428571, 0.13333333],
        [0.71428571, 0.11111111],
        [0.11904762, 0.03703704],
        [0.88095238, 0.85185185],
        [0.26190476, 0.98518519],
        [0.57142857, 0.37037037],
        [0.21428571, 0.28888889],
        [0.45238095, 0.2962963 ],
        [0.16666667, 0.05185185],
        [0.97619048, 0.94814815],
        [0.57142857, 0.28888889],
        [0.21428571, 0.01481481],
        [0.69047619, 0.25185185],
        [0.23809524, 0.32592593],
        [0.35714286, 0.4       ],
        [0.30952381, 0.39259259],
        [0.4047619 , 0.05925926],
        [0.4047619 , 0.05185185],
        [0.92857143, 0.13333333],
        [0.28571429, 0.68148148],
        [0.66666667, 0.05925926],
        [0.54761905, 0.33333333],
        [0.35714286, 0.19259259],
        [0.35714286, 0.33333333],
        [0.69047619, 0.23703704],
        [0.23809524, 0.2962963 ],
        [0.64285714, 0.12592593],
        [0.4047619 , 0.44444444],
        [0.97619048, 0.2       ],
        [0.69047619, 0.25925926],
        [0.69047619, 0.26666667],
        [0.4047619 , 0.47407407],
        [0.5       , 0.2       ],
        [0.52380952, 0.46666667],
        [0.11904762, 0.24444444],
        [0.5       , 0.44444444],
        [0.4047619 , 0.31111111],
        [0.88095238, 0.17777778],
        [0.30952381, 0.41481481],
        [0.23809524, 0.8       ],
        [0.57142857, 0.48148148],
        [0.52380952, 0.32592593],
        [0.71428571, 0.43703704],
        [0.07142857, 0.54074074],
        [0.83333333, 0.42222222],
        [0.21428571, 0.31851852],
        [0.4047619 , 0.23703704],
        [0.23809524, 0.54814815],
        [0.19047619, 0.48148148],
        [0.33333333, 0.75555556],
        [0.57142857, 0.44444444],
        [0.35714286, 0.99259259],
        [0.54761905, 0.48148148],
        [0.30952381, 0.31851852],
        [0.5       , 0.88148148],
        [0.21428571, 0.03703704],
        [0.26190476, 0.5037037 ],
        [0.02380952, 0.40740741],
        [0.02380952, 0.51851852],
        [0.26190476, 0.34074074],
        [0.5       , 0.88148148],
        [0.30952381, 0.45185185],
        [0.54761905, 0.42222222],
        [0.19047619, 0.48148148],
        [0.52380952, 0.34074074],
        [0.4047619 , 0.07407407],
        [0.71428571, 0.6       ],
        [0.57142857, 0.99259259],
        [0.23809524, 0.47407407],
        [0.78571429, 0.88148148],
        [0.35714286, 0.0962963 ],
        [0.57142857, 0.28888889],
        [0.64285714, 0.05185185],
        [0.45238095, 0.31111111],
        [0.38095238, 0.71851852],
        [0.4047619 , 0.17777778],
        [0.0952381 , 0.08888889],
        [0.4047619 , 0.42222222],
        [0.5       , 0.32592593],
        [0.04761905, 0.43703704],
        [0.66666667, 0.12592593],
        [0.19047619, 0.20740741],
        [0.26190476, 0.5037037 ],
        [0.88095238, 0.81481481],
        [0.45238095, 0.97037037],
        [0.64285714, 0.85925926],
        [0.35714286, 0.20740741],
        [0.54761905, 0.22222222],
        [0.57142857, 0.47407407],
        [0.45238095, 0.9037037 ],
        [0.14285714, 0.51111111],
        [0.33333333, 0.02222222],
        [0.9047619 , 0.65925926],
        [0.73809524, 0.17777778],
        [0.23809524, 0.51851852],
        [0.83333333, 0.94814815],
        [0.28571429, 0.54814815],
        [0.92857143, 0.33333333],
        [0.52380952, 0.44444444],
        [0.54761905, 0.47407407],
        [0.04761905, 0.4962963 ],
        [0.0952381 , 0.2962963 ],
        [0.4047619 , 0.54074074],
        [0.85714286, 0.40740741],
        [0.30952381, 0.        ],
        [0.76190476, 0.15555556],
        [0.57142857, 0.37037037],
        [0.38095238, 0.20740741],
        [0.57142857, 0.68888889],
        [0.85714286, 0.08148148],
        [0.02380952, 0.04444444],
        [0.42857143, 0.25925926],
        [0.45238095, 0.40740741],
        [0.42857143, 0.82222222],
        [0.69047619, 0.03703704],
        [0.28571429, 0.47407407],
        [0.97619048, 0.5037037 ],
        [0.26190476, 0.48148148],
        [0.5952381 , 0.71851852],
        [0.47619048, 0.48148148],
        [0.95238095, 0.95555556],
        [0.42857143, 0.81481481],
        [0.73809524, 0.0962963 ],
        [0.        , 0.27407407],
        [0.28571429, 0.        ],
        [0.97619048, 0.54074074],
        [0.21428571, 0.31111111],
        [0.47619048, 0.41481481],
        [0.30952381, 0.54814815],
        [0.69047619, 0.14074074],
        [0.30952381, 0.43703704],
        [0.45238095, 0.42222222],
        [0.52380952, 0.31111111],
        [0.97619048, 0.85185185],
        [0.73809524, 0.37037037],
        [0.71428571, 0.55555556],
        [0.66666667, 0.05185185],
        [0.85714286, 0.65925926],
        [0.4047619 , 0.03703704],
        [0.73809524, 0.93333333],
        [0.71428571, 0.19259259],
        [0.4047619 , 0.2962963 ],
        [0.42857143, 0.33333333],
        [0.33333333, 1.        ],
        [0.        , 0.39259259],
        [0.14285714, 0.2962963 ],
        [0.57142857, 0.55555556],
        [0.47619048, 0.32592593],
        [1.        , 0.68888889],
        [0.0952381 , 0.35555556],
        [0.14285714, 0.12592593],
        [0.66666667, 0.32592593],
        [0.71428571, 0.88148148],
        [0.54761905, 0.42222222],
        [0.76190476, 0.21481481],
        [0.47619048, 0.41481481],
        [0.14285714, 0.05925926],
        [0.4047619 , 0.34074074],
        [0.45238095, 0.48148148],
        [0.14285714, 0.08888889],
        [0.19047619, 0.51111111],
        [0.38095238, 0.07407407],
        [0.42857143, 0.28888889],
        [0.07142857, 0.39259259],
        [0.54761905, 0.41481481],
        [1.        , 0.2       ],
        [0.80952381, 0.55555556],
        [0.04761905, 0.05925926],
        [0.78571429, 0.05925926],
        [0.66666667, 0.47407407],
        [0.28571429, 0.25185185]]),
 array([[0.42857143, 0.13333333],
        [0.5       , 0.34074074],
        [0.42857143, 0.76296296],
        [0.5       , 0.79259259],
        [0.19047619, 0.76296296],
        [0.47619048, 0.37037037],
        [0.04761905, 0.15555556],
        [0.73809524, 0.54814815],
        [0.30952381, 0.02222222],
        [0.71428571, 0.93333333],
        [0.38095238, 0.42222222],
        [0.5       , 0.42962963],
        [0.4047619 , 0.42222222],
        [0.71428571, 0.85925926],
        [0.83333333, 0.4962963 ],
        [0.9047619 , 0.87407407],
        [1.        , 0.5037037 ],
        [0.21428571, 0.31851852],
        [0.23809524, 0.53333333],
        [1.        , 0.64444444],
        [0.52380952, 0.44444444],
        [0.76190476, 0.54074074],
        [0.61904762, 0.91851852],
        [0.69047619, 0.20740741],
        [0.64285714, 0.08148148],
        [0.19047619, 0.        ],
        [0.95238095, 0.23703704],
        [0.73809524, 0.43703704],
        [0.83333333, 0.14074074],
        [0.80952381, 0.73333333],
        [0.5       , 0.2       ],
        [0.02380952, 0.45185185],
        [0.        , 0.52592593],
        [0.92857143, 0.43703704],
        [0.21428571, 0.51111111],
        [0.28571429, 0.48148148],
        [0.0952381 , 0.02222222],
        [0.33333333, 0.52592593],
        [0.76190476, 0.03703704],
        [0.02380952, 0.07407407],
        [0.69047619, 0.95555556],
        [0.95238095, 0.63703704],
        [0.38095238, 0.74074074],
        [0.11904762, 0.37777778],
        [0.9047619 , 0.33333333],
        [0.30952381, 0.76296296],
        [0.71428571, 0.14814815],
        [0.69047619, 0.72592593],
        [0.5       , 0.47407407],
        [0.80952381, 0.17037037],
        [0.14285714, 0.31851852],
        [0.45238095, 0.28148148],
        [0.57142857, 0.48148148],
        [0.66666667, 0.0962963 ],
        [0.57142857, 0.42962963],
        [0.45238095, 0.34814815],
        [1.        , 0.2       ],
        [0.42857143, 0.27407407],
        [0.95238095, 0.59259259],
        [0.5952381 , 0.84444444],
        [0.21428571, 0.54814815],
        [0.11904762, 0.4962963 ],
        [0.47619048, 0.71851852],
        [0.4047619 , 0.25925926],
        [0.42857143, 0.62222222],
        [0.45238095, 0.95555556],
        [0.19047619, 0.14814815],
        [0.57142857, 0.40740741],
        [0.5952381 , 0.87407407],
        [0.47619048, 0.25925926],
        [0.66666667, 0.6       ],
        [0.4047619 , 0.21481481],
        [0.47619048, 0.72592593],
        [0.5       , 0.41481481],
        [0.19047619, 0.27407407],
        [0.85714286, 0.68888889],
        [0.35714286, 0.26666667],
        [0.19047619, 0.00740741],
        [0.28571429, 0.53333333],
        [0.4047619 , 0.33333333],
        [0.26190476, 0.0962963 ],
        [0.64285714, 0.05185185],
        [0.66666667, 0.75555556],
        [0.33333333, 0.02222222],
        [0.0952381 , 0.48888889],
        [0.16666667, 0.53333333],
        [0.71428571, 0.13333333],
        [0.4047619 , 0.31851852],
        [0.69047619, 0.05925926],
        [0.19047619, 0.12592593],
        [0.33333333, 0.88888889],
        [1.        , 0.14074074],
        [0.80952381, 0.04444444],
        [0.47619048, 0.2962963 ],
        [0.16666667, 0.55555556],
        [0.95238095, 0.17037037],
        [0.73809524, 0.54074074],
        [0.45238095, 0.46666667],
        [0.4047619 , 0.45925926],
        [0.38095238, 0.20740741]]),
 82     0
 367    1
 179    0
 27     1
 89     0
       ..
 255    1
 72     0
 396    1
 235    1
 37     0
 Name: Purchased, Length: 300, dtype: int64,
 398    0
 125    0
 328    1
 339    1
 172    0
       ..
 300    1
 277    1
 289    1
 260    0
 173    0
 Name: Purchased, Length: 100, dtype: int64]

 

# 실젝 학습과 예측을 하기 위해 변수로 저장

X_train, X_test, y_train, y_test = train_test_split( X , y, test_size=0.25, random_state=1  )

 

# 분류예측을 위한 인공지능이므로 linear 때와 import 다르니 주의!

from sklearn.linear_model import LogisticRegression

 

# 보통 분류문제 인공지능 명칭은 classfier로 명칭함
classifier = LogisticRegression()

 

# 학습
classifier.fit(X_train, y_train)

 

# 예측
classifier.predict( X_test )

array([0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1,
       1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
       1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1,
       0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0], dtype=int64)

 

# 예측은 하는데 0과 1사이 값을 그대로 보여줘라 (확률로) 두가지 모두 기억해 둘것!! (predict, predict_proba)
classifier.predict_proba(X_test)

array([[0.58140755, 0.41859245],
       [0.69539553, 0.30460447],
       [0.46748609, 0.53251391],
       [0.51296276, 0.48703724],
       [0.95013198, 0.04986802],
       [0.30087262, 0.69912738],
       [0.7530337 , 0.2469663 ],
       [0.53127652, 0.46872348],
       [0.78389265, 0.21610735],
       [0.64167867, 0.35832133],
       [0.63646856, 0.36353144],
       [0.78556309, 0.21443691],
       [0.05077763, 0.94922237],
       [0.3150107 , 0.6849893 ],
       [0.64878644, 0.35121356],
       [0.66908777, 0.33091223],
       [0.83541206, 0.16458794],
       [0.93353476, 0.06646524],
       [0.93720323, 0.06279677],
       [0.4125869 , 0.5874131 ],
       [0.91305277, 0.08694723],
       [0.9398901 , 0.0601099 ],
       [0.67962733, 0.32037267],
       [0.10364045, 0.89635955],
       [0.21694011, 0.78305989],
       [0.75382073, 0.24617927],
       [0.715914  , 0.284086  ],
       [0.36565234, 0.63434766],
       [0.54670903, 0.45329097],
       [0.83811072, 0.16188928],
       [0.8812389 , 0.1187611 ],
       [0.85826331, 0.14173669],
       [0.8611571 , 0.1388429 ],
       [0.53935259, 0.46064741],
       [0.19926929, 0.80073071],
       [0.84599512, 0.15400488],
       [0.35912721, 0.64087279],
       [0.30117245, 0.69882755],
       [0.71129766, 0.28870234],
       [0.89457311, 0.10542689],
       [0.69359781, 0.30640219],
       [0.58209398, 0.41790602],
       [0.96636639, 0.03363361],
       [0.88884895, 0.11115105],
       [0.74907504, 0.25092496],
       [0.88545712, 0.11454288],
       [0.158939  , 0.841061  ],
       [0.7530337 , 0.2469663 ],
       [0.659328  , 0.340672  ],
       [0.93176154, 0.06823846],
       [0.72388544, 0.27611456],
       [0.48438641, 0.51561359],
       [0.49920185, 0.50079815],
       [0.73722062, 0.26277938],
       [0.63385137, 0.36614863],
       [0.67158361, 0.32841639],
       [0.29761793, 0.70238207],
       [0.8784037 , 0.1215963 ],
       [0.4037186 , 0.5962814 ],
       [0.04043595, 0.95956405],
       [0.87960488, 0.12039512],
       [0.27512425, 0.72487575],
       [0.74480548, 0.25519452],
       [0.37420696, 0.62579304],
       [0.07319853, 0.92680147],
       [0.42737138, 0.57262862],
       [0.58140755, 0.41859245],
       [0.06792376, 0.93207624],
       [0.97747239, 0.02252761],
       [0.63155565, 0.36844435],
       [0.69149437, 0.30850563],
       [0.53092644, 0.46907356],
       [0.92703878, 0.07296122],
       [0.89457311, 0.10542689],
       [0.82870577, 0.17129423],
       [0.65996138, 0.34003862],
       [0.86548768, 0.13451232],
       [0.94979652, 0.05020348],
       [0.11248004, 0.88751996],
       [0.58997023, 0.41002977]])

 

# 실제값과 비교해 보기 위해서 예측한 값을 변수로 저장

y_pred = classifier.predict(X_test)

 

# 실제값을 df_test 변수로 저장

df_test = y_test.to_frame()

 

# 실제값에 예측값을 컬럼으로 삽입하여 비교

df_test['y_pred'] = y_pred

 

df_test

ㄴ 실제값과 예측값이 일치하면 인공지능이 맞춘거고 다르면 틀린것

 

# 인공지능을 만들었으면 예측 데이터를 가지고 '정확도'를 확인하는것이 중요하다.
# 정확도 계산은 0은 0으로 맞춘것, 0인데 1로 틀린거, 1인데 1로 맞춘것, 1인데 0으로 틀린것
# 전체케이스를 모두 더하고, 0을 0으로 맞춘것 1을 1로 맞춘것 을 더해서 => 맞춘것 / 전체데이터 = 정확도
# 이런 정확도를 계산하는것이 하단의 Confusion Matrix 분류 결과표

< Confusion Matrix >

 

두 개의 클래스로 분류하는 경우는 아래와 같다.

 

# 실제 코드로 불러와 보자.

 

from sklearn.metrics import confusion_matrix

 

confusion_matrix(y_test, y_pred)

array([[52,  6],
       [14, 28]], dtype=int64)

ㄴ 해석이 중요하다
# [0인데 0으로 맞춘것, 0인데 1로 맞춘것
#  1인데 0으로 맞춘것, 1인데 1로 맞춘것]

 

# 전체값을 봐보자.

cm = confusion_matrix(y_test, y_pred)

cm.sum()

100

 

# 정확도 계산 : accuracy
(52+28) / cm.sum()

0.8

ㄴ 80%

 

# 정확도 계산 라이브러리
from sklearn.metrics import accuracy_score

 

accuracy_score(y_test, y_pred)

0.8

 

## 분류 예측 결과를 종합적으로 보고 싶다면,

from sklearn.metrics import classification_report

 

# 메모리 상태를 찍은거기 때문에 해당 형태로 출력됨

classification_report(y_test, y_pred)

'              precision    recall  f1-score   support\n\n           0       0.79      0.90      0.84        58\n           1       0.82      0.67      0.74        42\n\n    accuracy                           0.80       100\n   macro avg       0.81      0.78      0.79       100\nweighted avg       0.80      0.80      0.80       100\n'

 

# 사람이 보기좋은 형태로 보려면 print로 호출
print(classification_report(y_test, y_pred))

 precision    recall  f1-score   support

           0       0.79      0.90      0.84        58
           1       0.82      0.67      0.74        42

    accuracy                           0.80       100
   macro avg       0.81      0.78      0.79       100
weighted avg       0.80      0.80      0.80       100

 

< 데이터 시각화>

# cm 결과값 숫자들을 색으로 표현하고 싶을때
import seaborn as sb

 

sb.heatmap(data = cm, cmap = 'RdPu' , annot= True)
plt.show()

 

# 이렇게까지만 작업을 수행하면됨

 

다음 게시글로 계속

반응형