梯度下降法实现softmax回归MATLAB程序
梯度下降法實現softmax回歸MATLAB程序
版權聲明:本文原創,轉載須注明來源。
解決二分類問題時我們通常用Logistic回歸,而解決多分類問題時若果用Logistic回歸,則需要設計多個分類器,這是相當麻煩的事情。softmax回歸可以看做是Logistic回歸的普遍推廣(Logistic回歸可看成softmax回歸在類別數為2時的特殊情況),在多分類問題上softmax回歸是一個有效的工具。
關于softmax回歸算法的理論知識可參考這兩篇博文:http://deeplearning.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92 ;
http://blog.csdn.net/acdreamers/article/details/44663305 。
本文自編mysoftmax_gd函數用于實現梯度下降softmax回歸,代碼如下(鏈接:http://pan.baidu.com/s/1geF2WMJ 密碼:9x3x):
MATLAB程序代碼:
function [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR,varargin) % 該函數用于實現梯度下降法softmax回歸 % 調用方式:[theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR,varargin) % X_test:測試輸入數據 % X:訓練輸入數據,組織為m*p矩陣,m為案例個數,p為加上常數項之后的屬性個數 % label:訓練數據標簽,組織為m*1向量(數值型) % lambda:權重衰減參數weight decay parameter % alpha:梯度下降學習速率 % MAX_ITR:最大迭代次數 % varargin:可選參數,輸入初始迭代的theta系數,若不輸入,則默認隨機選取 % theta:梯度下降法的theta系數尋優結果 % test_pre:測試數據預測標簽 % rata:訓練數據回判正確率% Genlovy Hoo,2016.06.29. genlovhyy@163.com %% 梯度下降尋優 Nin=length(varargin); if Nin>1error('輸入太多參數') % 若可選輸入參數超過1個,則報錯 end [m,p] = size(X); numClasses = length(unique(label)); % 求取標簽類別數 if Nin==0theta = 0.005*randn(p,numClasses); % 若沒有輸入可選參數,則隨機初始化系數 elsetheta=varargin{1}; % 若有輸入可選參數,則將其設定為初始theta系數 end cost=zeros(MAX_ITR,1); % 用于追蹤代價函數的值 for k=1:MAX_ITR[cost(k),grad] = softmax_cost_grad(X,label,lambda,theta); % 計算代價函數值和梯度theta=theta-alpha*grad; % 更新系數 end %% 回判預測 [~,~,Probit] = softmax_cost_grad(X,label,lambda,theta); [~,label_pre] = max(Probit,[],2); index = find(label==label_pre); % 找出預測正確的樣本的位置 rate = length(index)/m; % 計算預測精度 %% 繪制代價函數圖 figure('Name','代價函數值變化圖'); plot(0:MAX_ITR-1,cost) xlabel('迭代次數'); ylabel('代價函數值') title('代價函數值變化圖');% 繪制代價函數值變化圖 %% 測試數據預測 [mt,pt] = size(X_test); Probit_t = zeros(mt,length(unique(label))); for smpt = 1:mtProbit_t(smpt,:) = exp(X_test(smpt,:)*theta)/sum(exp(X_test(smpt,:)*theta)); end [~,test_pre] = max(Probit_t,[],2); function [cost,thetagrad,P] = softmax_cost_grad(X,label,lambda,theta) % 用于計算代價函數值及其梯度 % X:m*p輸入矩陣,m為案例個數,p為加上常數項之后的屬性個數 % label:m*1標簽向量(數值型) % lambda:權重衰減參數weight decay parameter % theta:p*k系數矩陣,k為標簽類別數 % cost:總代價函數值 % thetagrad:梯度矩陣 % P:m*k分類概率矩陣,P(i,j)表示第i個樣本被判別為第j類的概率 m = size(X,1); % 將每個標簽擴展為一個k維橫向量(k為標簽類別數),若樣本i屬于第j類,則 % label_extend(i,j)= 1,否則label_extend(i,j)= 0。 label_extend = [full(sparse(label,1:length(label),1))]'; % 計算預測概率矩陣 P = zeros(m,size(label_extend,2)); for smp = 1:mP(smp,:) = exp(X(smp,:)*theta)/sum(exp(X(smp,:)*theta)); end % 計算代價函數值 cost = -1/m*[label_extend(:)]'*log(P(:))+lambda/2*sum(theta(:).^2); % 計算梯度 thetagrad = -1/m*X'*(label_extend-P)+lambda*theta; clear clc close all load fisheriris % MATLAB自帶數據集 % 對標簽重新編號并準備訓練/測試數據集 index_train = [1:40,51:90,101:140]; index_test = [41:50,91:100,141:150]; species_train = species(index_train); X=[ones(length(species_train),1),meas(index_train,:)]; label = zeros(size(species_train)); label(strcmp('setosa',species_train)) = 1; label(strcmp('versicolor',species_train)) = 2; label(strcmp('virginica',species_train)) = 3; species_test = species(index_test); X_test = [ones(length(species_test),1),meas(index_test,:)]; lambda = 0.004; % 權重衰減參數Weight decay parameter alpha = 0.1; % 學習速率 MAX_ITR=500; % 最大迭代次數 [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR) clear clc close all load MNISTdata % MNIST數據集 % 準備訓練/測試數據集 label = labels(1:9000); % 訓練集標簽 X = [ones(length(label),1),[inputData(:,1:9000)]']; % 訓練集輸入數據 label_test = labels(9001:end); % 測試集標簽 X_test = [ones(length(label_test),1),[inputData(:,9001:end)]']; % 測試輸入數據 lambda = 0.004; % 權重衰減參數Weight decay parameter alpha = 0.1; % 學習速率 MAX_ITR=100; % 最大迭代次數 [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR) index_t = find(label_test==test_pre); % 找出預測正確的樣本的位置 rate_test = length(index_t)/length(label_test); % 計算預測精度水平有限,敬請指正交流。genlovhyy@163.com 。
參考資料:
【1】:http://deeplearning.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
【2】:http://blog.csdn.net/acdreamers/article/details/44663305
總結
以上是生活随笔為你收集整理的梯度下降法实现softmax回归MATLAB程序的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: .net core创建区域(Areas
- 下一篇: java socket / No buf