机器学习第7天:深入了解逻辑回归
文章目錄
- 一、邏輯回歸是什么
- 二、邏輯回歸的代價函數
- 1. 公式:
- 2. 公式推導過程:
- 2.1. 尋找預測函數
- 2.2. 構造代價函數
- 三、梯度下降法求J(θ)的最小值
- 四、代碼展示
一、邏輯回歸是什么
簡單來說, 邏輯回歸(Logistic Regression)是一種用于解決二分類(0 or 1)問題的機器學習方法,用于估計某種事物的可能性。比如某用戶購買某商品的可能性,某病人患有某種疾病的可能性,以及某廣告被用戶點擊的可能性等。
邏輯回歸是為了解決分類問題,根據一些已知的訓練集訓練好模型,再對新的數據進行預測屬于哪個類。
邏輯回歸(Logistic Regression)與線性回歸(Linear Regression)都是一種廣義線性模型(generalized linear model)。邏輯回歸假設因變量 y 服從伯努利分布,而線性回歸假設因變量 y 服從高斯分布。
二、邏輯回歸的代價函數
1. 公式:
綜合起來為:
其中
2. 公式推導過程:
代價函數的推導分兩步進行:
2.1. 尋找預測函數
Logistic Regression雖然名字里帶“回歸”,但是它實際上是一種分類方法,用于兩分類問題(即輸出只有兩種),顯然,預測函數的輸出必須是兩個值(分別代表兩個類別),所以利用了Logistic函數(或稱為Sigmoid函數)。
sigmoid函數是一個s形的曲線,它的取值在[0, 1]之間,在遠離0的地方函數的值會很快接近0或者1。它的這個特性對于解決二分類問題十分重要。
Sigmoid函數:
接下來需要確定數據劃分的邊界類型,對于圖1和圖2中的兩種數據分布,顯然圖1需要一個線性的邊界,而圖2需要一個非線性的邊界。接下來我們只討論線性邊界的情況。
圖1 圖2對于線性邊界的情況,邊界形式如下:
構造預測函數為:
hθ(x)函數的值有特殊的含義,它表示結果取1的概率,因此對于輸入x分類結果為類別1和類別0的概率分別為:
2.2. 構造代價函數
上面的n改成m,筆誤。
三、梯度下降法求J(θ)的最小值
θ更新過程可以寫成:
四、代碼展示
def LogisticRegression():data = loadtxtAndcsv_data("data2.txt", ",", np.float64) X = data[:,0:-1]y = data[:,-1]plot_data(X,y) # 作圖X = mapFeature(X[:,0],X[:,1]) #映射為多項式initial_theta = np.zeros((X.shape[1],1))#初始化thetainitial_lambda = 0.1 #初始化正則化系數,一般取0.01,0.1,1.....J = costFunction(initial_theta,X,y,initial_lambda) #計算一下給定初始化的theta和lambda求出的代價Jprint(J) #輸出一下計算的值,應該為0.693147#result = optimize.fmin(costFunction, initial_theta, args=(X,y,initial_lambda)) #直接使用最小化的方法,效果不好'''調用scipy中的優(yōu)化算法fmin_bfgs(擬牛頓法Broyden-Fletcher-Goldfarb-Shanno)- costFunction是自己實現(xiàn)的一個求代價的函數,- initial_theta表示初始化的值,- fprime指定costFunction的梯度- args是其余測參數,以元組的形式傳入,最后會將最小化costFunction的theta返回 '''result = optimize.fmin_bfgs(costFunction, initial_theta, fprime=gradient, args=(X,y,initial_lambda)) p = predict(X, result) #預測print(u'在訓練集上的準確度為%f%%'%np.mean(np.float64(p==y)*100)) # 與真實值比較,p==y返回True,轉化為float X = data[:,0:-1]y = data[:,-1] plotDecisionBoundary(result,X,y) #畫決策邊界感覺有困難可以先放著,后期會進行更加具體的介紹,知道這么幾個公式就好了。
參考文章:
Logistic回歸計算過程的推導
邏輯回歸(Logistic Regression)
Coursera ML筆記 - 邏輯回歸
邏輯回歸 - 理論篇
總結
以上是生活随笔為你收集整理的机器学习第7天:深入了解逻辑回归的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 实战项目三:爬取QQ群中的人员信息
- 下一篇: 机器学习第8天:IPyhon与Jupyt