# KNN (k-nearest neighbours)
# 目录 (Contents)
# 导入包 (Import packages)
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import plotly.graph_objects as go
# 导入数据 (Load the data)
# Load the Social_Network_Ads CSV; the path is relative to the working directory.
data = pd.read_csv("./datasets/Social_Network_Ads.csv")

# Features are the columns at positions 2 and 3; the target is the column at
# position 4.
# NOTE(review): presumably Age / EstimatedSalary and Purchased — confirm
# against the CSV header.
X = data.iloc[:, [2, 3]].values
Y = data.iloc[:, 4].values

# Optional visual check: scatter the two features, coloured by label.
# scatter = go.Scatter(x=X[:,0],y=X[:,1],mode='markers',marker={'color':Y})
# fig = go.Figure(scatter)
# fig.show()

# Hold out 25% of the rows for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.25, random_state=0
)
# 标准化 (Standardisation)
from sklearn.preprocessing import StandardScaler
# Standardise the features: KNN is distance-based, so both features should be
# on a comparable scale.
sca = StandardScaler()
# Fit the scaling statistics on the training split only ...
X_train = sca.fit_transform(X_train)
# ... and reuse them on the test split, so no test-set information leaks into
# the scaler.
X_test = sca.transform(X_test)
# 训练模型 (Train the model)
from sklearn.neighbors import KNeighborsClassifier
# Classify by the 5 nearest neighbours; with the estimator's default Minkowski
# metric, p=2 gives Euclidean distance.
model = KNeighborsClassifier(n_neighbors=5, p=2)
model.fit(X_train, Y_train)
# Notebook output of the fit call above: KNeighborsClassifier()
# 模型得分 (Model score)
# Mean accuracy of the fitted classifier on the held-out test split.
# Outside a notebook a bare expression is discarded, so bind and print it.
accuracy = model.score(X_test, Y_test)
print(accuracy)  # the original notebook run displayed 0.93