线性回归实例-鸢尾花数据集
文章目錄
- 一、具體實現步驟
- 1. 導入Iris鳶尾花數據集
- 2. 提取花瓣數據
- 3. 拆分數據
- 4. 訓練模型
- 二、可視化結果展示
- 1. 訓練集
- 2. 測試集
- 三、相關知識點講解
- 1. train_test_split()函數
- 2. LinearRegression()函數
- 3. 散點圖與折線統計圖的繪制
這篇文章中,我們要通過鳶尾花的花瓣長度預測花瓣寬度
- 環境:Python3.6.5
- 編譯器:jupyter notebook
- 代碼|數據:微信公眾號(明天依舊可好)中回復:第1天
一、具體實現步驟
1. 導入Iris鳶尾花數據集
Iris鳶尾花數據集共有150條記錄,分別是:
- 50條山鳶尾 (Iris-setosa)
- 50條變色鳶尾(Iris-versicolor)
- 50條維吉尼亞鳶尾(Iris-virginica)
2. 提取花瓣數據
下面我們提取數據集中花瓣寬度與花瓣長度數據,將花瓣數據分為訓練數據與測試數據,訓練數據用于訓練線性回歸模型,測試數據用于檢測我們的模型的準確率。
最終我們要達到的效果是:輸入花瓣寬度,通過模型預測花瓣寬度。
X = dataset["花瓣-length"] Y = dataset["花瓣-width"] X = X.reshape(len(X),1) Y = Y.reshape(len(Y),1)3. 拆分數據
將數據集拆分數據集成訓練集、測試集
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)4. 訓練模型
這里我們需要將我們的訓練數據喂給模型進行訓練。
from sklearn.linear_model import LinearRegression regressor = LinearRegression() regressor = regressor.fit(X_train, Y_train)二、可視化結果展示
1. 訓練集
將訓練集中每一朵花的花瓣數據與線性回歸模型預測的結果放到同一張統計圖中。
import matplotlib.pyplot as pltplt.scatter(X_train, Y_train, color='red') plt.plot(X_train, regressor.predict(X_train), color='green') plt.xlabel("Iris-length") plt.ylabel("Iris-width") plt.title("This is train dataset-kzb") plt.show()紅色的點是訓練數據集中的花瓣數據,我們不難看出花瓣長度與寬度是一個線性關系,綠色的線是我們模型擬合的結果。
2. 測試集
將測試集中每一朵花的花瓣數據與線性回歸模型預測的結果放到同一張統計圖中。
plt.scatter(X_test, Y_test, color='blue') plt.plot(X_train, regressor.predict(X_train), color='green') plt.xlabel("Iris-length") plt.ylabel("Iris-width") plt.title("This is test dataset-kzb") plt.show()綠色的點是測試數據集中的花瓣數據,我們可以看出這部分數據也是符合線性關系的,隨著集的增大,線性關系會更加明顯。
三、相關知識點講解
1. train_test_split()函數
train_test_split():將數據集劃分為測試集與訓練集。
- X:所要劃分的整體數據的特征集;
- Y:所要劃分的整體數據的結果;
- test_size:測試集數據量在整體數據量中的占比(可以理解為X_test與X的比值);
- random_state:①若不填或者填0,每次生成的數據都是隨機,可能不一樣。②若為整數,每次生成的數據都相同;
2. LinearRegression()函數
sklearn.linear_model包實現了廣義線性模型,包括線性回歸、Ridge回歸、Bayesian回歸等。LinearRegression是其中較為簡單的線性回歸模型。
解釋一下什么是回歸:回歸最簡單的定義是,給出一個點集D,用一個函數去擬合這個點集,并且使得點集與擬合函數間的誤差最小,如果這個函數曲線是一條直線,那就被稱為線性回歸,如果曲線是一條二次曲線,就被稱為二次回歸。
3. 散點圖與折線統計圖的繪制
plt.scatter():繪畫出數據的散點圖
plt.plot():繪畫出依據模型(LinearRegression的線性回歸模型)生成的直線
更詳細的介紹可以參考:【Matplotlib可視化系列教程】
【機器學習100天目錄】
【機器學習第2天:線性回歸(理論篇)】
【機器學習第3天:預測汽車的燃油效率】
【機器學習第4天:預測1立方米混凝土抗壓強度】
【機器學習第5天:邏輯回歸】
如有錯誤歡迎指教,有問題的也可以加入QQ群(1149530473)向我提問,關注微信公眾號(明天依舊可好)和我同步。
總結
以上是生活随笔為你收集整理的线性回归实例-鸢尾花数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 机器学习第2天:简单线性回归模型
- 下一篇: 线性回归的概念