生活随笔
收集整理的這篇文章主要介紹了
Java应用梯度下降求解线性SVM模型参考代码
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
下面的代碼是參考網上的,直接執行,主要是為了后續進一步掌握SVM原理而發布。
兩個基本原理還是要去掌握:SVM原理和梯度下降法。
1)SVM分類器:
支持向量機,因其英文名為support vector machine,故一般簡稱SVM,通俗來講,它是一種二類分類模型,其基本模型定義為特征空間上的間隔最大的線性分類器,其學習策略便是間隔最大化,最終可轉化為一個凸二次規劃問題的求解。
我自己的理解是:特征空間上尋找一個最優平面來分類,這個最優的求解其實就是多維的約束規劃問題。
2)梯度下降法:
梯度下降法,就是利用負梯度方向來決定每次迭代的新的搜索方向,使得每次迭代能使待優化的目標函數逐步減小。梯度下降法是2范數下的最速下降法。 最速下降法的一種簡單形式是:x(k+1)=x(k)-a*g(k),其中a稱為學習速率,可以是較小的常數。g(k)是x(k)的梯度。
我自己的理解是:切線求導數,梯度下降法是求解SVM的一種方法。
在實際文本分類中,怎么求解SVM,應該要根據實際來選擇方法,如拉格朗日、對偶、核函數等,如果理解超平面比較復雜的話,可以用二維平面及其點到直線的距離來抽象理解多維度超平面空間的分類。
package sk.svm;import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.StringTokenizer;//梯度下降法
public class SimpleSvm {private int exampleNum;private int exampleDim;private double[] w;private double lambda;private double lr = 0.001;//0.00001private double threshold = 0.001;private double cost;private double[] grad;private double[] yp;public SimpleSvm(double paramLambda){lambda = paramLambda; }private void CostAndGrad(double[][] X,double[] y){//梯度求解cost =0;for(int m=0;m<exampleNum;m++){yp[m]=0;for(int d=0;d<exampleDim;d++){yp[m]+=X[m][d]*w[d];}if(y[m]*yp[m]-1<0){cost += (1-y[m]*yp[m]);}}for(int d=0;d<exampleDim;d++){cost += 0.5*lambda*w[d]*w[d];}for(int d=0;d<exampleDim;d++){grad[d] = Math.abs(lambda*w[d]); for(int m=0;m<exampleNum;m++){if(y[m]*yp[m]-1<0){grad[d]-= y[m]*X[m][d];}}} }private void update(){for(int d=0;d<exampleDim;d++){w[d] -= lr*grad[d];}}public void Train(double[][] X,double[] y,int maxIters){exampleNum = X.length;if(exampleNum <=0) {System.out.println("num of example <=0!");return;}exampleDim = X[0].length;w = new double[exampleDim];grad = new double[exampleDim];yp = new double[exampleNum];for(int iter=0;iter<maxIters;iter++){CostAndGrad(X,y);System.out.println("cost:"+cost);if(cost< threshold){break;}update(); }}private int predict(double[] x){double pre=0;for(int j=0;j<x.length;j++){pre+=x[j]*w[j];}if(pre >=0)//這個閾值一般位于-1到1return 1;else return -1;}public void Test(double[][] testX,double[] testY){int error=0;for(int i=0;i<testX.length;i++){if(predict(testX[i]) != testY[i]){error++;}}System.out.println("total:"+testX.length);System.out.println("error:"+error);System.out.println("error rate:"+((double)error/testX.length));System.out.println("acc rate:"+((double)(testX.length-error)/testX.length));}public static void loadData(double[][]X,double[] y,String trainFile) throws IOException{File file = new File(trainFile);RandomAccessFile raf = new RandomAccessFile(file,"r");StringTokenizer tokenizer,tokenizer2; int index=0;while(true){String line = raf.readLine();if(line == null) break;tokenizer = new StringTokenizer(line," ");y[index] = Double.parseDouble(tokenizer.nextToken());while(tokenizer.hasMoreTokens()){tokenizer2 = new StringTokenizer(tokenizer.nextToken(),":");int k = Integer.parseInt(tokenizer2.nextToken());double v = Double.parseDouble(tokenizer2.nextToken());X[index][k] = v; } X[index][0] =1;index++; }}public static void main(String[] args) throws IOException {// TODO Auto-generated method stubdouble[] y = new double[400];double[][] X = new double[400][11];String trainFile = "D:\\tmp\\train_bc";loadData(X,y,trainFile);SimpleSvm svm = new SimpleSvm(0.0001);svm.Train(X,y,7000);double[] test_y = new double[283];double[][] test_X = new double[283][11];String testFile = "D:\\tmp\\test_bc";loadData(test_X,test_y,testFile);svm.Test(test_X, test_y);}
}
源代碼和數據集下載:https://github.com/linger2012/simpleSvm
總結
以上是生活随笔為你收集整理的Java应用梯度下降求解线性SVM模型参考代码的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。