< Multiple Linear Regression 이란? >
다중 선형 회귀(Multiple Linear Regression)는 두 개 이상의 독립 변수가 종속 변수에 미치는 영향을 분석하는 회귀 분석 방법이다.
단순 선형 회귀(Linear Regression)와 달리, 다중 선형 회귀는 종속 변수와 관련된 여러 개의 독립 변수를 사용하여 예측 모델을 구축한다.
아래처럼, 여러개의 features 를 기반으로, 수익을 예측하려 한다.
위와 같이, 여러개의 변수들을 통해, 수익과의 관계를 분석하고,
이를 통해, 새로운 데이터가 들어왔을 때, 수익이 어떻게 될 지를 예측하고자 한다.
아래는 하나의 변수일때와, 여러개의 변수가 있을때의 leaner regression 을 나타낸다.
2차원에서는 선 이지만, 3차원에서는 평면이 된다.
자, 이제, 오차가 가장 적을때의 b 값들을 찾아보자.
1. 먼저 식을 세운다. 이때 숫자가 아닌값은 어떻게 처리해야 할까?
catergorical 로 바꿔주면 된다.
# Profit 수익을 예측하려 한다. 이것이 디펜더블 베리어블, 나머지는 인디펜더블 베리어블
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 데이터 불러오기
# 50_Startups.csv 데이터를 읽으세요.
# 각각의 피쳐를 분석하여, 어떤 신생 회사의 데이터가 있으면, 그 회사가 얼마의 수익을 낼 지 예측합니다. (투자를 해야 할지 말아야 할지)
df = pd.read_csv('../data/50_Startups.csv')
...
...
# 1) NaN 처리
df.isna().sum()
R&D Spend 0
Administration 0
Marketing Spend 0
State 0
Profit 0
dtype: int64
# 2) 없으므로 X,y 로 분리
# 예상을해야 하는 수익을 y 열
y = df['Profit']
# 나머지 컬럼은 모두 X로 지정
X = df.iloc[ : , 0 : 3+1 ]
# 3) 문자열 데이터는 숫자로 인코딩
# 데이터를 확인해 보면 X에 State에만 문자열이 있기때문에 X만 인코딩 필요
X['State'].nunique()
3
# 문자열 눈으로 확인 숫자로 인코딩할경우 스펠링 순으로 매겨지기 때문에 꼭 눈으로 확인 필요!!
sorted(df['State'].unique() )
['California', 'Florida', 'New York']
# 갯수가 3개 이상이므로 원-핫 인코딩 필요
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
# 원-핫 인코딩의 경우 ColumnTransformer 를 사용해 인코딩을 해주어야 함.
# ([('임의의 변수명(보통 encoder를 사용', OneHotEncoder(), [인코딩할 컬럼의 인데스 숫자 여기선 3열 이므로 3])], remainder= 'passthrough' => 해당 열을 제외하고는 모두 통과시킨다는 의미로 꼭 적어주어야함)
ct = ColumnTransformer( [('encoder' , OneHotEncoder(), [ 3 ])] , remainder= 'passthrough' )
# 원-핫 인코딩된 컬럼이 항상 맨 왼쪽에 위치하게 된다.
# 즉 State 가 원-핫 인코딩이 되어서 세개의 컬럼으로 분할됨
# array([[ 캘리포니아 , 플로리다 , 뉴욕 , R&D , Admin , Marketing 순으로 출력된것]])
# 원한인코더로 만든 변수 ct를 이용하여 인코딩 진행
X = ct.fit_transform(X)
X
array([[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.6534920e+05,
1.3689780e+05, 4.7178410e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.6259770e+05,
1.5137759e+05, 4.4389853e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.5344151e+05,
1.0114555e+05, 4.0793454e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.4437241e+05,
1.1867185e+05, 3.8319962e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.4210734e+05,
9.1391770e+04, 3.6616842e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.3187690e+05,
9.9814710e+04, 3.6286136e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3461546e+05,
1.4719887e+05, 1.2771682e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3029813e+05,
1.4553006e+05, 3.2387668e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.2054252e+05,
1.4871895e+05, 3.1161329e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.2333488e+05,
1.0867917e+05, 3.0498162e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0191308e+05,
1.1059411e+05, 2.2916095e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0067196e+05,
9.1790610e+04, 2.4974455e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.3863750e+04,
1.2732038e+05, 2.4983944e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.1992390e+04,
1.3549507e+05, 2.5266493e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.1994324e+05,
1.5654742e+05, 2.5651292e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.1452361e+05,
1.2261684e+05, 2.6177623e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.8013110e+04,
1.2159755e+05, 2.6434606e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 9.4657160e+04,
1.4507758e+05, 2.8257431e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.1749160e+04,
1.1417579e+05, 2.9491957e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 8.6419700e+04,
1.5351411e+05, 0.0000000e+00],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.6253860e+04,
1.1386730e+05, 2.9866447e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.8389470e+04,
1.5377343e+05, 2.9973729e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.3994560e+04,
1.2278275e+05, 3.0331926e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.7532530e+04,
1.0575103e+05, 3.0476873e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.7044010e+04,
9.9281340e+04, 1.4057481e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4664710e+04,
1.3955316e+05, 1.3796262e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.5328870e+04,
1.4413598e+05, 1.3405007e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.2107600e+04,
1.2786455e+05, 3.5318381e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.6051520e+04,
1.8264556e+05, 1.1814820e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.5605480e+04,
1.5303206e+05, 1.0713838e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.1994480e+04,
1.1564128e+05, 9.1131240e+04],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.1136380e+04,
1.5270192e+05, 8.8218230e+04],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.3408860e+04,
1.2921961e+05, 4.6085250e+04],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 5.5493950e+04,
1.0305749e+05, 2.1463481e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.6426070e+04,
1.5769392e+05, 2.1079767e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 4.6014020e+04,
8.5047440e+04, 2.0551764e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.8663760e+04,
1.2705621e+05, 2.0112682e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4069950e+04,
5.1283140e+04, 1.9702942e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 2.0229590e+04,
6.5947930e+04, 1.8526510e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8558510e+04,
8.2982090e+04, 1.7499930e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.8754330e+04,
1.1854605e+05, 1.7279567e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.7892920e+04,
8.4710770e+04, 1.6447071e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3640930e+04,
9.6189630e+04, 1.4800111e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.5505730e+04,
1.2738230e+05, 3.5534170e+04],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.2177740e+04,
1.5480614e+05, 2.8334720e+04],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.0002300e+03,
1.2415304e+05, 1.9039300e+03],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3154600e+03,
1.1581621e+05, 2.9711446e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
1.3542692e+05, 0.0000000e+00],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 5.4205000e+02,
5.1743150e+04, 0.0000000e+00],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
1.1698380e+05, 4.5173060e+04]])
# 4) 피쳐 스케일링이 필요하지만 Linear Regression 은 피쳐스케일링이 필요없어 패스
# 5) training / test 로 나눈다.
from sklearn.model_selection import train_test_split
train_test_split(X, y, test_size=0.2, random_state=65)
[array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4664710e+04,
1.3955316e+05, 1.3796262e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.3994560e+04,
1.2278275e+05, 3.0331926e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.1992390e+04,
1.3549507e+05, 2.5266493e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
1.1698380e+05, 4.5173060e+04],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.2107600e+04,
1.2786455e+05, 3.5318381e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.6426070e+04,
1.5769392e+05, 2.1079767e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.3863750e+04,
1.2732038e+05, 2.4983944e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.8013110e+04,
1.2159755e+05, 2.6434606e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.3187690e+05,
9.9814710e+04, 3.6286136e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3640930e+04,
9.6189630e+04, 1.4800111e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.1136380e+04,
1.5270192e+05, 8.8218230e+04],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.4437241e+05,
1.1867185e+05, 3.8319962e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.5344151e+05,
1.0114555e+05, 4.0793454e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.6534920e+05,
1.3689780e+05, 4.7178410e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0067196e+05,
9.1790610e+04, 2.4974455e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.4210734e+05,
9.1391770e+04, 3.6616842e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.2054252e+05,
1.4871895e+05, 3.1161329e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 8.6419700e+04,
1.5351411e+05, 0.0000000e+00],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.7044010e+04,
9.9281340e+04, 1.4057481e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 4.6014020e+04,
8.5047440e+04, 2.0551764e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.1994480e+04,
1.1564128e+05, 9.1131240e+04],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.5505730e+04,
1.2738230e+05, 3.5534170e+04],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8558510e+04,
8.2982090e+04, 1.7499930e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.8663760e+04,
1.2705621e+05, 2.0112682e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.2177740e+04,
1.5480614e+05, 2.8334720e+04],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0191308e+05,
1.1059411e+05, 2.2916095e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.5328870e+04,
1.4413598e+05, 1.3405007e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.2333488e+05,
1.0867917e+05, 3.0498162e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.1452361e+05,
1.2261684e+05, 2.6177623e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 5.5493950e+04,
1.0305749e+05, 2.1463481e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4069950e+04,
5.1283140e+04, 1.9702942e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 5.4205000e+02,
5.1743150e+04, 0.0000000e+00],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.6259770e+05,
1.5137759e+05, 4.4389853e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 2.0229590e+04,
6.5947930e+04, 1.8526510e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3461546e+05,
1.4719887e+05, 1.2771682e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.5605480e+04,
1.5303206e+05, 1.0713838e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3029813e+05,
1.4553006e+05, 3.2387668e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.8389470e+04,
1.5377343e+05, 2.9973729e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.8754330e+04,
1.1854605e+05, 1.7279567e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3154600e+03,
1.1581621e+05, 2.9711446e+05]]),
array([[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.7892920e+04,
8.4710770e+04, 1.6447071e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 9.4657160e+04,
1.4507758e+05, 2.8257431e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.6253860e+04,
1.1386730e+05, 2.9866447e+05],
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.0002300e+03,
1.2415304e+05, 1.9039300e+03],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.7532530e+04,
1.0575103e+05, 3.0476873e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.1994324e+05,
1.5654742e+05, 2.5651292e+05],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.6051520e+04,
1.8264556e+05, 1.1814820e+05],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
1.3542692e+05, 0.0000000e+00],
[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.3408860e+04,
1.2921961e+05, 4.6085250e+04],
[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.1749160e+04,
1.1417579e+05, 2.9491957e+05]]),
25 107404.34
22 110352.25
13 134307.35
49 14681.40
27 105008.31
34 96712.80
12 141585.52
16 126992.93
5 156991.12
42 71498.49
31 97483.56
3 182901.99
2 191050.39
0 192261.83
11 144259.40
4 166187.94
8 152211.77
19 122776.86
24 108552.04
35 96479.51
30 99937.59
43 69758.98
39 81005.76
36 90708.19
44 65200.33
10 146121.95
26 105733.54
9 149759.96
15 129917.04
33 96778.92
37 89949.14
48 35673.41
1 191792.06
38 81229.06
6 156122.51
29 101004.64
7 155752.60
21 111313.02
40 78239.91
46 49490.75
Name: Profit, dtype: float64,
41 77798.83
17 125370.37
20 118474.03
45 64926.08
23 108733.99
14 132602.65
28 103282.38
47 42559.73
32 97427.84
18 124266.90
Name: Profit, dtype: float64]
# train/test 변수로 나누어서 메모리에 업로드
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=65)
# 깡통 인공지능 만들기
from sklearn.linear_model import LinearRegression
# 학습하는 함수는 .fil()
regressor.fit(X_train, y_train)
# 계수
regressor.coef_
array([ 8.29736108e+00, 1.35646415e+03, -1.36476151e+03, 8.24637324e-01,
-1.12195852e-02, 2.80920611e-02])
# 상수
regressor.intercept_
46989.22920268966
# 검증(.predict) 하고 MSE를 구해봐라!!
regressor.predict(X_test)
array([ 75017.13778857, 129992.67932128, 116991.86227872, 45109.83439244,
111410.63726075, 152704.94532349, 104084.06108572, 45478.09269869,
99131.67843245, 131009.36870589])
y_pred = regressor.predict(X_test)
# 이제 실제값 - 예측값으로 오차 범위를 구해야 한다.
y_test - y_pred
41 2781.692211
17 -4622.309321
20 1482.167721
45 19816.245608
23 -2676.647261
14 -20102.295323
28 -801.681086
47 -2918.362699
32 -1703.838432
18 -6742.468706
Name: Profit, dtype: float64
# 성능측정 = 오차를 제곱해서, 부호를 먼저 없앤 후에 평균을 구한다.
error = y_test - y_pred
error ** 2
41 7.737812e+06
17 2.136574e+07
20 2.196821e+06
45 3.926836e+08
23 7.164441e+06
14 4.041023e+08
28 6.426926e+05
47 8.516841e+06
32 2.903065e+06
18 4.546088e+07
Name: Profit, dtype: float64
# 해당 값이 MSE (Mean squared error)
(error ** 2).mean()
89277416.70428361
# RMSE
# 수치가 너무 크니까 MSE에 루트를 씌우자 => 그냥 참고정도로 알고있으면 되고 MSE 만 알고있어도됨
# 성능평가할때 사용됨
np.sqrt((error ** 2).mean())
9448.672748290292
###즉 regression 평가에서는 이 인공지능이 잘동작되는지를 평가하는 기준이 MSE인 것!! ###
< 시각화도 해보자 >
# 차트로 그리기 위해서 실제값과 예측값을 합치기 위해 판다스 시리즈를 데이터 프레임으로 전환
# 시리즈에 열을 붙이면 시리즈(1차원)가 아니기때문에 붙일수가 없다. 그러므로 데이터 프레임으로 전환하는것
df_test = y_test.to_frame()
# 예측값을 실제값에 컬럼으로 삽입
df_test['y_pred'] = y_pred
# 인덱스가 순서대로 되어있지 않기때문에 순서대로 정렬하여 그려야 이쁘게 그려짐
# drop 없이 실행하면 인덱스 컬럼까지 생기기 때문에 인덱스 컬럼은 가져오지 않게 drop=True를 사용
df_test.reset_index(drop=True, inplace=True)
df_test.plot()
plt.show()
# 차트를 bar 형태로 보고 싶을때
# 이미지 파일을 저장하고 싶으면 plt.savefig()
# 저장 명령어는 plt.show() 전에 적어야 그대로 저장된다 순서대로 실행됨
df_test.plot(kind='bar')
plt.savefig('test.jpg')
plt.show()
< 실제 예제문을 통해 실제 업무에 적용해본다고 가정해 보자! >
# 1) 만들어진 인공지능을 파일로 저장하여 실서버에 배포!!
# 2) 운영비는 15만 달러, 마케팅비는 40만 달러, 연구개발비 13만 달러이고 회사 위치는 Florida에 있다.
# 이 회사의 수익을 예측하세요.
# B 회사는 운영비 11만달러, 마케팅비 60만 달러, 연구개발비 15만 달러 회사위치는 뉴욕
# 이 두 회사의 수익을 예측하세요.
# 학습
# regressor.fit( )
# 예측
# regressor.predict( )
### 예측에 기준이 되는 X_test 형식과 같은 numpy 2차원 데이터로 예측값을 형성해야 한다 ###
# 1차원 데이터값 생성
new_data = np.array([ 0, 1, 0, 130000, 150000, 400000 ])
# 1차원이기 때문에 실행이 안되므로 2차원으로 변경해줘야한다.
new_data = new_data.reshape(1, 6)
new_data
array([[ 0, 1, 0, 130000, 150000, 400000]])
# 새로운 데이터를 기존의 학습시킨 regressor로 예측
regressor.predict( new_data )
array([165102.43212675])
## 2번째 예제
# 2차원 배열로 데이터를 생성하였으므로 상단에 regressor.predict( new_date2 ) 바로 적용 가능
new_date2 = np.array([[ 0, 1, 0, 130000, 150000, 400000 ],[0, 0, 1, 150000, 110000, 600000]])
new_date2
array([[ 0, 1, 0, 130000, 150000, 400000],
[ 0, 0, 1, 150000, 110000, 600000]])
regressor.predict( new_date2 )
array([165102.43212675, 184941.14856592])
# 위에 두가지 방법은 하드 코딩 (직접 일일이 인코딩된 문자열의 위치까지 확인하여 숫자값으로 입력한것 너무 비효율적)
data = {'R&D' : [130000, 150000] , 'Admin' : [150000, 110000], 'Marketing' : [400000, 600000] , 'State' : ['Florida','New York']}
# 딕셔너리로 생성한후 데이터프레임으로 변경하여
new_data = pd.DataFrame(data)
## 번외) 원본 컬렴명과 신규데이터 컬럼명이 맞질않아 오류가 생겨서 컬럼명을 맞춰줌 (.columns 기억하면 좋음)
new_data.columns = df.columns[ 0 : -2+1 ]
new_data
ㄴ 컬럼명까지 맞춰진것 확인
# 본문 에서 설정했던 인코딩 방법으로 인코딩
ct.transform(new_data)
array([[0.0e+00, 1.0e+00, 0.0e+00, 1.3e+05, 1.5e+05, 4.0e+05],
[0.0e+00, 0.0e+00, 1.0e+00, 1.5e+05, 1.1e+05, 6.0e+05]])
# 인코딩된 결과값을 변수로 저장
new_data = ct.transform(new_data)
regressor.predict(new_data)
array([165102.43212675, 184941.14856592])
# 실 서버에 이 인공지능을 활용하려면
# 2개 파일이 필요하다.
# 실제 예측을 위한(가장 중요한) regressor와 인코딩을 위한 ct => 파일로 저장
# 저장하는 방법
import joblib
# joblib.dump 함수 사용 ( 저장할 파일 , '파일 이름 . 확장자는 pkl')
joblib.dump( regressor, 'regressor.pkl' )
joblib.dump( ct, 'ct.pkl' )
ㄴ 파일 생성 확인. 이 파일을 서버로 보내주면 된다.
# 실제 전달받았다고 가정하여 다른 곳에서 불러와서 실행해 보자.
# 새로운 주피터 노트북 혹은 사용 tool 페이지에서 실행
import joblib
ct = joblib.load('ct.pkl')
ct
regressor = joblib.load('regressor.pkl')
regressor
# 신규데이터 예측해 보자
# 스테이트는 캘리포니아, 연구비 18만, 운영비 20만, 마케팅비 15만 일때
# 이회사의 수익은??
data = {'R&D Spend' : [180000] , 'Administration' : [200000], 'Marketing Spend' : [150000] , 'State' : ['California']}
data
{'R&D Spend': [180000],
'Administration': [200000],
'Marketing Spend': [150000],
'State': ['California']}
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
new_data = pd.DataFrame(data)
new_data
ct.transform(new_data)
array([[1.0e+00, 0.0e+00, 0.0e+00, 1.8e+05, 2.0e+05, 1.5e+05]])
new_data = ct.transform(new_data)
regressor.predict(new_data)
array([197402.13702065])