< Logistic Regression 이란?? >
- Logistic Regression은 기계 학습과 통계에서 사용되는 통계적 분류 방법
- 주로 두 개 이상의 클래스 중 하나에 속하는 경우를 예측하는 데 사용된다.
- 주로 이진 분류(binary classification) 문제에 적용되며, 예를 들어 스팸 메일 여부 판별, 질병 진단 등 다양한 분야에서 활용된다.
- 이 방법은 선형 회귀(Linear Regression)와 비슷해 보이지만, 출력 값이 0과 1 사이의 확률값으로 제한되며, S자 형태의 로지스틱 함수를 사용하고, 이를 통해 예측값을 확률로 해석할 수 있다.
- 간단히 말하면, Logistic Regression은 주어진 입력 변수를 기반으로 특정 클래스에 속할 확률을 예측하는 모델
분류에 사용한다. (Classification)
예) 나이대별로 이메일을 클릭해서 열지 말지를 분류해 보자.
이메일 클릭을 할 사람과 안할 사람으로 분류할 것이다.
빨간점이 바로 데이터이며,
액션의 0 과 1 이 바로 레이블이다.
레이블이 있다는것은, 수퍼바이저드 러닝 이라는 뜻
이렇게 비슷하게 생긴 함수가 이미존재한다. 이름은 sigmoid function
따라서 리니어 리그레션 식을, y 값을 시그모이드에 대입해서, 일차방정식으로 만들면 다음과 같아진다.
위와 같은 식을 가진 regression 을, Logistic Regression이라 한다.
이제 우리는, 이를 가지고 두개의 클래스로 분류할 수 있다. ( 클릭을 한다, 안한다 두개로.)
확률로 나타낼 수 있게 되었다.
p는 확률값을 나타낸다.
20대는 클릭할 확률이 0.7%, 40대는 85%, 50대는 99.4%
이 확률값은, 위에서의 시그모이드 함수를 적용한 식을 통해 나온 값임을 기억한다.
최종 예측값은, 0.5를 기준으로 두개의 부류로 나눈다. 그 값은 0 과 1 이다.
< 이제 실제 예제문을 통해 Logistic Regression 실행해 보자 >
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 새로운 데이터를 불러와 실행
df = pd.read_csv('../data/Social_Network_Ads.csv')
df
# 나이와 연봉으로 분석해서, 물건을 구매할지 안할지를 분류하자
# 1) NaN 확인
df.isna().sum()
User ID 0
Gender 0
Age 0
EstimatedSalary 0
Purchased 0
dtype: int64
# 2) X,y 분류
y = df['Purchased']
y
0 0
1 0
2 0
3 0
4 0
..
395 1
396 1
397 1
398 0
399 1
Name: Purchased, Length: 400, dtype: int64
X = df.loc[ : , 'Age' : 'EstimatedSalary' ]
X
# 3) 인코딩이 필요하지만 이번 경우엔
# y의 경우 0,1로만 데이터가 구성되어 있어서 이미 레이블링 인코딩이 된상태로 봐도 무관
# X 데이터도 다 숫자로 되어있어서 인코딩 필요 없음
y.unique()
array([0, 1], dtype=int64)
# 4) 로지스틱 리그레이션은, 피쳐 스케일링을 하자. (린이어는 필요 없었음)
# X값은 두 컬럼의 숫자범위 차이가 크므로 스케일링 필요
# y는 0,1로만 구성되어 있으므로 필요 없음
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# 둘중의 원하는것 사용
# X_scalerS = StandardScaler()
# 이번엔 정규화를 사용해 보자
# import한 인코딩 기능을 사용하기 위해 변수로 저장
X_scalerM = MinMaxScaler()
# 사용법
X = X_scalerM.fit_transform( X )
# 5) train, test 용으로 나누기
from sklearn.model_selection import train_test_split
# 트레인 테스트 범위 설정, X를 참고하여 y를 예측하고 train 75%, test는 25%를 사용해라.
train_test_split( X , y, test_size=0.25, random_state=1 )
[array([[0.04761905, 0.25185185],
[0.66666667, 0.54074074],
[0.30952381, 0.14074074],
[0.69047619, 0.11111111],
[0.4047619 , 0.25925926],
[0.5 , 0.6 ],
[0.35714286, 0.72592593],
[0.73809524, 0.52592593],
[0.64285714, 0.47407407],
[0.61904762, 0.17777778],
[0.54761905, 0.32592593],
[0.57142857, 0.28148148],
[0.4047619 , 0.42962963],
[0.54761905, 0.42222222],
[0.21428571, 0.6 ],
[0.28571429, 0.74814815],
[0.54761905, 0.27407407],
[0.54761905, 0.27407407],
[0.04761905, 0.4962963 ],
[0.66666667, 0.19259259],
[0.21428571, 0.11851852],
[0.4047619 , 0.41481481],
[0.73809524, 0.0962963 ],
[0.4047619 , 0.56296296],
[0.45238095, 0.44444444],
[0.33333333, 0.75555556],
[0.42857143, 0.44444444],
[0.04761905, 0.52592593],
[0.47619048, 0.25925926],
[0.73809524, 0.15555556],
[0.52380952, 0.37037037],
[0.45238095, 0.45925926],
[1. , 0.22962963],
[0.71428571, 0.91111111],
[0.52380952, 0.41481481],
[0.42857143, 0.35555556],
[0.19047619, 0.48888889],
[0.35714286, 0.11851852],
[0.54761905, 0.26666667],
[0.66666667, 0.43703704],
[0. , 0.4962963 ],
[0.45238095, 0.13333333],
[0.4047619 , 0.28148148],
[0.23809524, 0.32592593],
[0.45238095, 0.41481481],
[0. , 0.21481481],
[0.47619048, 0.34074074],
[0.4047619 , 0.37037037],
[0.30952381, 0.37777778],
[0.69047619, 0.07407407],
[0.57142857, 0.65925926],
[0.02380952, 0.02962963],
[0.07142857, 0.42222222],
[0.23809524, 0.12592593],
[0.80952381, 1. ],
[0.83333333, 0.65925926],
[0.16666667, 0.48148148],
[0.19047619, 0.42222222],
[0.19047619, 0. ],
[0.52380952, 0.94074074],
[0.4047619 , 0.17037037],
[0.92857143, 0.79259259],
[0.54761905, 0.53333333],
[0.14285714, 0.54814815],
[0.45238095, 0.27407407],
[0.4047619 , 0.68888889],
[0.33333333, 0.62962963],
[0.4047619 , 0.97777778],
[0.02380952, 0.08148148],
[0.66666667, 0.4962963 ],
[0.69047619, 0.66666667],
[0.64285714, 0.22222222],
[0.47619048, 0.26666667],
[0.45238095, 0.57777778],
[0.26190476, 0.20740741],
[0.54761905, 0.11111111],
[0.52380952, 0.33333333],
[0.21428571, 0.54074074],
[0.21428571, 0.55555556],
[0.5 , 0.41481481],
[0.11904762, 0.0962963 ],
[0.78571429, 0.97037037],
[0.11904762, 0.35555556],
[0.45238095, 0.48148148],
[0.52380952, 0.23703704],
[0.71428571, 0.1037037 ],
[0.4047619 , 0.32592593],
[0.19047619, 0.11111111],
[0.5 , 0.67407407],
[0.23809524, 0.51111111],
[0.97619048, 0.45185185],
[0.52380952, 0.31111111],
[0.4047619 , 0.60740741],
[0.19047619, 0.52592593],
[0.54761905, 0.35555556],
[0.26190476, 0.23703704],
[0.19047619, 0.01481481],
[0.95238095, 0.05925926],
[0.28571429, 0.34814815],
[0.16666667, 0.13333333],
[0.97619048, 0.1037037 ],
[0.23809524, 0.16296296],
[0.5 , 0.45925926],
[0.69047619, 0.68148148],
[0.80952381, 0.91111111],
[0.26190476, 0.20740741],
[0.21428571, 0.9037037 ],
[0.14285714, 0.2962963 ],
[0.52380952, 0.42222222],
[0.07142857, 0.00740741],
[0.92857143, 0.08148148],
[0.28571429, 0.88888889],
[0.71428571, 0.77037037],
[0.52380952, 0.31111111],
[0.4047619 , 0.44444444],
[0.23809524, 0.21481481],
[0.45238095, 0.43703704],
[0.4047619 , 0.08888889],
[0.45238095, 0.47407407],
[0.33333333, 0.77777778],
[0.26190476, 0.44444444],
[0.28571429, 0.01481481],
[0.16666667, 0.47407407],
[0.52380952, 0.68148148],
[0.14285714, 0.02962963],
[0.54761905, 0.42222222],
[0.47619048, 0.34074074],
[0.42857143, 0.95555556],
[0.57142857, 0.36296296],
[0.71428571, 0.13333333],
[0.71428571, 0.11111111],
[0.11904762, 0.03703704],
[0.88095238, 0.85185185],
[0.26190476, 0.98518519],
[0.57142857, 0.37037037],
[0.21428571, 0.28888889],
[0.45238095, 0.2962963 ],
[0.16666667, 0.05185185],
[0.97619048, 0.94814815],
[0.57142857, 0.28888889],
[0.21428571, 0.01481481],
[0.69047619, 0.25185185],
[0.23809524, 0.32592593],
[0.35714286, 0.4 ],
[0.30952381, 0.39259259],
[0.4047619 , 0.05925926],
[0.4047619 , 0.05185185],
[0.92857143, 0.13333333],
[0.28571429, 0.68148148],
[0.66666667, 0.05925926],
[0.54761905, 0.33333333],
[0.35714286, 0.19259259],
[0.35714286, 0.33333333],
[0.69047619, 0.23703704],
[0.23809524, 0.2962963 ],
[0.64285714, 0.12592593],
[0.4047619 , 0.44444444],
[0.97619048, 0.2 ],
[0.69047619, 0.25925926],
[0.69047619, 0.26666667],
[0.4047619 , 0.47407407],
[0.5 , 0.2 ],
[0.52380952, 0.46666667],
[0.11904762, 0.24444444],
[0.5 , 0.44444444],
[0.4047619 , 0.31111111],
[0.88095238, 0.17777778],
[0.30952381, 0.41481481],
[0.23809524, 0.8 ],
[0.57142857, 0.48148148],
[0.52380952, 0.32592593],
[0.71428571, 0.43703704],
[0.07142857, 0.54074074],
[0.83333333, 0.42222222],
[0.21428571, 0.31851852],
[0.4047619 , 0.23703704],
[0.23809524, 0.54814815],
[0.19047619, 0.48148148],
[0.33333333, 0.75555556],
[0.57142857, 0.44444444],
[0.35714286, 0.99259259],
[0.54761905, 0.48148148],
[0.30952381, 0.31851852],
[0.5 , 0.88148148],
[0.21428571, 0.03703704],
[0.26190476, 0.5037037 ],
[0.02380952, 0.40740741],
[0.02380952, 0.51851852],
[0.26190476, 0.34074074],
[0.5 , 0.88148148],
[0.30952381, 0.45185185],
[0.54761905, 0.42222222],
[0.19047619, 0.48148148],
[0.52380952, 0.34074074],
[0.4047619 , 0.07407407],
[0.71428571, 0.6 ],
[0.57142857, 0.99259259],
[0.23809524, 0.47407407],
[0.78571429, 0.88148148],
[0.35714286, 0.0962963 ],
[0.57142857, 0.28888889],
[0.64285714, 0.05185185],
[0.45238095, 0.31111111],
[0.38095238, 0.71851852],
[0.4047619 , 0.17777778],
[0.0952381 , 0.08888889],
[0.4047619 , 0.42222222],
[0.5 , 0.32592593],
[0.04761905, 0.43703704],
[0.66666667, 0.12592593],
[0.19047619, 0.20740741],
[0.26190476, 0.5037037 ],
[0.88095238, 0.81481481],
[0.45238095, 0.97037037],
[0.64285714, 0.85925926],
[0.35714286, 0.20740741],
[0.54761905, 0.22222222],
[0.57142857, 0.47407407],
[0.45238095, 0.9037037 ],
[0.14285714, 0.51111111],
[0.33333333, 0.02222222],
[0.9047619 , 0.65925926],
[0.73809524, 0.17777778],
[0.23809524, 0.51851852],
[0.83333333, 0.94814815],
[0.28571429, 0.54814815],
[0.92857143, 0.33333333],
[0.52380952, 0.44444444],
[0.54761905, 0.47407407],
[0.04761905, 0.4962963 ],
[0.0952381 , 0.2962963 ],
[0.4047619 , 0.54074074],
[0.85714286, 0.40740741],
[0.30952381, 0. ],
[0.76190476, 0.15555556],
[0.57142857, 0.37037037],
[0.38095238, 0.20740741],
[0.57142857, 0.68888889],
[0.85714286, 0.08148148],
[0.02380952, 0.04444444],
[0.42857143, 0.25925926],
[0.45238095, 0.40740741],
[0.42857143, 0.82222222],
[0.69047619, 0.03703704],
[0.28571429, 0.47407407],
[0.97619048, 0.5037037 ],
[0.26190476, 0.48148148],
[0.5952381 , 0.71851852],
[0.47619048, 0.48148148],
[0.95238095, 0.95555556],
[0.42857143, 0.81481481],
[0.73809524, 0.0962963 ],
[0. , 0.27407407],
[0.28571429, 0. ],
[0.97619048, 0.54074074],
[0.21428571, 0.31111111],
[0.47619048, 0.41481481],
[0.30952381, 0.54814815],
[0.69047619, 0.14074074],
[0.30952381, 0.43703704],
[0.45238095, 0.42222222],
[0.52380952, 0.31111111],
[0.97619048, 0.85185185],
[0.73809524, 0.37037037],
[0.71428571, 0.55555556],
[0.66666667, 0.05185185],
[0.85714286, 0.65925926],
[0.4047619 , 0.03703704],
[0.73809524, 0.93333333],
[0.71428571, 0.19259259],
[0.4047619 , 0.2962963 ],
[0.42857143, 0.33333333],
[0.33333333, 1. ],
[0. , 0.39259259],
[0.14285714, 0.2962963 ],
[0.57142857, 0.55555556],
[0.47619048, 0.32592593],
[1. , 0.68888889],
[0.0952381 , 0.35555556],
[0.14285714, 0.12592593],
[0.66666667, 0.32592593],
[0.71428571, 0.88148148],
[0.54761905, 0.42222222],
[0.76190476, 0.21481481],
[0.47619048, 0.41481481],
[0.14285714, 0.05925926],
[0.4047619 , 0.34074074],
[0.45238095, 0.48148148],
[0.14285714, 0.08888889],
[0.19047619, 0.51111111],
[0.38095238, 0.07407407],
[0.42857143, 0.28888889],
[0.07142857, 0.39259259],
[0.54761905, 0.41481481],
[1. , 0.2 ],
[0.80952381, 0.55555556],
[0.04761905, 0.05925926],
[0.78571429, 0.05925926],
[0.66666667, 0.47407407],
[0.28571429, 0.25185185]]),
array([[0.42857143, 0.13333333],
[0.5 , 0.34074074],
[0.42857143, 0.76296296],
[0.5 , 0.79259259],
[0.19047619, 0.76296296],
[0.47619048, 0.37037037],
[0.04761905, 0.15555556],
[0.73809524, 0.54814815],
[0.30952381, 0.02222222],
[0.71428571, 0.93333333],
[0.38095238, 0.42222222],
[0.5 , 0.42962963],
[0.4047619 , 0.42222222],
[0.71428571, 0.85925926],
[0.83333333, 0.4962963 ],
[0.9047619 , 0.87407407],
[1. , 0.5037037 ],
[0.21428571, 0.31851852],
[0.23809524, 0.53333333],
[1. , 0.64444444],
[0.52380952, 0.44444444],
[0.76190476, 0.54074074],
[0.61904762, 0.91851852],
[0.69047619, 0.20740741],
[0.64285714, 0.08148148],
[0.19047619, 0. ],
[0.95238095, 0.23703704],
[0.73809524, 0.43703704],
[0.83333333, 0.14074074],
[0.80952381, 0.73333333],
[0.5 , 0.2 ],
[0.02380952, 0.45185185],
[0. , 0.52592593],
[0.92857143, 0.43703704],
[0.21428571, 0.51111111],
[0.28571429, 0.48148148],
[0.0952381 , 0.02222222],
[0.33333333, 0.52592593],
[0.76190476, 0.03703704],
[0.02380952, 0.07407407],
[0.69047619, 0.95555556],
[0.95238095, 0.63703704],
[0.38095238, 0.74074074],
[0.11904762, 0.37777778],
[0.9047619 , 0.33333333],
[0.30952381, 0.76296296],
[0.71428571, 0.14814815],
[0.69047619, 0.72592593],
[0.5 , 0.47407407],
[0.80952381, 0.17037037],
[0.14285714, 0.31851852],
[0.45238095, 0.28148148],
[0.57142857, 0.48148148],
[0.66666667, 0.0962963 ],
[0.57142857, 0.42962963],
[0.45238095, 0.34814815],
[1. , 0.2 ],
[0.42857143, 0.27407407],
[0.95238095, 0.59259259],
[0.5952381 , 0.84444444],
[0.21428571, 0.54814815],
[0.11904762, 0.4962963 ],
[0.47619048, 0.71851852],
[0.4047619 , 0.25925926],
[0.42857143, 0.62222222],
[0.45238095, 0.95555556],
[0.19047619, 0.14814815],
[0.57142857, 0.40740741],
[0.5952381 , 0.87407407],
[0.47619048, 0.25925926],
[0.66666667, 0.6 ],
[0.4047619 , 0.21481481],
[0.47619048, 0.72592593],
[0.5 , 0.41481481],
[0.19047619, 0.27407407],
[0.85714286, 0.68888889],
[0.35714286, 0.26666667],
[0.19047619, 0.00740741],
[0.28571429, 0.53333333],
[0.4047619 , 0.33333333],
[0.26190476, 0.0962963 ],
[0.64285714, 0.05185185],
[0.66666667, 0.75555556],
[0.33333333, 0.02222222],
[0.0952381 , 0.48888889],
[0.16666667, 0.53333333],
[0.71428571, 0.13333333],
[0.4047619 , 0.31851852],
[0.69047619, 0.05925926],
[0.19047619, 0.12592593],
[0.33333333, 0.88888889],
[1. , 0.14074074],
[0.80952381, 0.04444444],
[0.47619048, 0.2962963 ],
[0.16666667, 0.55555556],
[0.95238095, 0.17037037],
[0.73809524, 0.54074074],
[0.45238095, 0.46666667],
[0.4047619 , 0.45925926],
[0.38095238, 0.20740741]]),
82 0
367 1
179 0
27 1
89 0
..
255 1
72 0
396 1
235 1
37 0
Name: Purchased, Length: 300, dtype: int64,
398 0
125 0
328 1
339 1
172 0
..
300 1
277 1
289 1
260 0
173 0
Name: Purchased, Length: 100, dtype: int64]
# 실젝 학습과 예측을 하기 위해 변수로 저장
X_train, X_test, y_train, y_test = train_test_split( X , y, test_size=0.25, random_state=1 )
# 분류예측을 위한 인공지능이므로 linear 때와 import 다르니 주의!
from sklearn.linear_model import LogisticRegression
# 보통 분류문제 인공지능 명칭은 classfier로 명칭함
classifier = LogisticRegression()
# 학습
classifier.fit(X_train, y_train)
# 예측
classifier.predict( X_test )
array([0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1,
1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1,
0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0], dtype=int64)
# 예측은 하는데 0과 1사이 값을 그대로 보여줘라 (확률로) 두가지 모두 기억해 둘것!! (predict, predict_proba)
classifier.predict_proba(X_test)
array([[0.58140755, 0.41859245],
[0.69539553, 0.30460447],
[0.46748609, 0.53251391],
[0.51296276, 0.48703724],
[0.95013198, 0.04986802],
[0.30087262, 0.69912738],
[0.7530337 , 0.2469663 ],
[0.53127652, 0.46872348],
[0.78389265, 0.21610735],
[0.64167867, 0.35832133],
[0.63646856, 0.36353144],
[0.78556309, 0.21443691],
[0.05077763, 0.94922237],
[0.3150107 , 0.6849893 ],
[0.64878644, 0.35121356],
[0.66908777, 0.33091223],
[0.83541206, 0.16458794],
[0.93353476, 0.06646524],
[0.93720323, 0.06279677],
[0.4125869 , 0.5874131 ],
[0.91305277, 0.08694723],
[0.9398901 , 0.0601099 ],
[0.67962733, 0.32037267],
[0.10364045, 0.89635955],
[0.21694011, 0.78305989],
[0.75382073, 0.24617927],
[0.715914 , 0.284086 ],
[0.36565234, 0.63434766],
[0.54670903, 0.45329097],
[0.83811072, 0.16188928],
[0.8812389 , 0.1187611 ],
[0.85826331, 0.14173669],
[0.8611571 , 0.1388429 ],
[0.53935259, 0.46064741],
[0.19926929, 0.80073071],
[0.84599512, 0.15400488],
[0.35912721, 0.64087279],
[0.30117245, 0.69882755],
[0.71129766, 0.28870234],
[0.89457311, 0.10542689],
[0.69359781, 0.30640219],
[0.58209398, 0.41790602],
[0.96636639, 0.03363361],
[0.88884895, 0.11115105],
[0.74907504, 0.25092496],
[0.88545712, 0.11454288],
[0.158939 , 0.841061 ],
[0.7530337 , 0.2469663 ],
[0.659328 , 0.340672 ],
[0.93176154, 0.06823846],
[0.72388544, 0.27611456],
[0.48438641, 0.51561359],
[0.49920185, 0.50079815],
[0.73722062, 0.26277938],
[0.63385137, 0.36614863],
[0.67158361, 0.32841639],
[0.29761793, 0.70238207],
[0.8784037 , 0.1215963 ],
[0.4037186 , 0.5962814 ],
[0.04043595, 0.95956405],
[0.87960488, 0.12039512],
[0.27512425, 0.72487575],
[0.74480548, 0.25519452],
[0.37420696, 0.62579304],
[0.07319853, 0.92680147],
[0.42737138, 0.57262862],
[0.58140755, 0.41859245],
[0.06792376, 0.93207624],
[0.97747239, 0.02252761],
[0.63155565, 0.36844435],
[0.69149437, 0.30850563],
[0.53092644, 0.46907356],
[0.92703878, 0.07296122],
[0.89457311, 0.10542689],
[0.82870577, 0.17129423],
[0.65996138, 0.34003862],
[0.86548768, 0.13451232],
[0.94979652, 0.05020348],
[0.11248004, 0.88751996],
[0.58997023, 0.41002977]])
# 실제값과 비교해 보기 위해서 예측한 값을 변수로 저장
y_pred = classifier.predict(X_test)
# 실제값을 df_test 변수로 저장
df_test = y_test.to_frame()
# 실제값에 예측값을 컬럼으로 삽입하여 비교
df_test['y_pred'] = y_pred
df_test
ㄴ 실제값과 예측값이 일치하면 인공지능이 맞춘거고 다르면 틀린것
# 인공지능을 만들었으면 예측 데이터를 가지고 '정확도'를 확인하는것이 중요하다.
# 정확도 계산은 0은 0으로 맞춘것, 0인데 1로 틀린거, 1인데 1로 맞춘것, 1인데 0으로 틀린것
# 전체케이스를 모두 더하고, 0을 0으로 맞춘것 1을 1로 맞춘것 을 더해서 => 맞춘것 / 전체데이터 = 정확도
# 이런 정확도를 계산하는것이 하단의 Confusion Matrix 분류 결과표
< Confusion Matrix >
두 개의 클래스로 분류하는 경우는 아래와 같다.
# 실제 코드로 불러와 보자.
from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_pred)
array([[52, 6],
[14, 28]], dtype=int64)
ㄴ 해석이 중요하다
# [0인데 0으로 맞춘것, 0인데 1로 맞춘것
# 1인데 0으로 맞춘것, 1인데 1로 맞춘것]
# 전체값을 봐보자.
cm = confusion_matrix(y_test, y_pred)
cm.sum()
100
# 정확도 계산 : accuracy
(52+28) / cm.sum()
0.8
ㄴ 80%
# 정확도 계산 라이브러리
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)
0.8
## 분류 예측 결과를 종합적으로 보고 싶다면,
from sklearn.metrics import classification_report
# 메모리 상태를 찍은거기 때문에 해당 형태로 출력됨
classification_report(y_test, y_pred)
' precision recall f1-score support\n\n 0 0.79 0.90 0.84 58\n 1 0.82 0.67 0.74 42\n\n accuracy 0.80 100\n macro avg 0.81 0.78 0.79 100\nweighted avg 0.80 0.80 0.80 100\n'
# 사람이 보기좋은 형태로 보려면 print로 호출
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 0.79 0.90 0.84 58
1 0.82 0.67 0.74 42
accuracy 0.80 100
macro avg 0.81 0.78 0.79 100
weighted avg 0.80 0.80 0.80 100
< 데이터 시각화>
# cm 결과값 숫자들을 색으로 표현하고 싶을때
import seaborn as sb
sb.heatmap(data = cm, cmap = 'RdPu' , annot= True)
plt.show()
# 이렇게까지만 작업을 수행하면됨