MNN量化—ADMM优化算法
作者:糖心他爸
鏈接:https://zhuanlan.zhihu.com/p/81243626
來源:知乎
目錄
量化的模型建立
ADMM算法
量化的模型建立
現在我們知道了如何做量化推斷,下一步是如何去建立量化模型,或者說我們應該用一個什么樣的方式才能求得量化權重和量化輸入呢?現在已知我們的輸入為fp32,我們想用int8來對原fp32的數據進行表示,其中的轉化關系假設為:
fp32和int8值之間的轉化關系
對于每個數據塊,我們都希望用一個尺度因子來對其進行int8的刻畫,這個尺度因子就是上面式子中的s。經過上面式子之后,我們就能找到int8值和fp32值的一個對應關系。下面我們看看具體的從fp32到int8的值求解過程,這里假設我們已經知道了尺度因子s和階段關系,那么int8的值我們可以通過如下式子獲得:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
int8值的求解過程
其中E是一個分段函數,表示如下。在數軸上看這個函數,是一個分段線性函數,且每個分段的導數為0,在階梯跳躍處的整數點不可導。這個性質是一個很好用的性質,文中后續將會用到這個性質。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
截斷函數
在經過了這一系列過程后,我們就由fp32值得到了int8的值,我們進一步抽象這個過程,為了后面的數學推導做準備。為此,我們再定義一個函數f,我們稱之為編碼函數;反過來f的逆,我們稱之為反編碼函數,由上式我們很容易得到:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
所以,在進行int8值的量化編碼時候,我們執行了如下操作,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
int8的編碼過程,由于進行了數據截斷,該過程不可逆,所以這個是一個信息損失的過程,為了度量損失的程度,我們下面將會介紹兩種思路,也是MNN中解決該問題的兩種算法。其目的都是為了,盡量的在編碼過程中減少信息的損失。一種是從概率分布的角度出發的,目的是讓編碼后的數據分布和編碼前的數據分布盡可能的一致,采用的概率分布度量為KL-散度;另一種是從最優化優化角度出發的,目的是讓編碼后的數據在進行反編碼后的結果,盡可能和原數據接近,采用的度量可以是L2、L1或是廣義范數,但由于L2度量的優化算法求解方便,我們后面也主要基于L2度量進行討論。
?
ADMM算法
MNN中的ADMM算法是從優化的角度出發,來保證編碼前后數據盡可能相似的方法。
首先定義一個目標:也就是希望編碼后的數據(int8)經過反編碼后,跟原數據的”距離“盡可能的接近。該描述可以通過如下數學公式進行刻畫:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
進一步的,假設我們這里選定的度量D為L2度量,f和E的選取跟文章一開始時候介紹的一致,我們可以將上面式子轉化如下目標函數:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
所以,求解尺度因子s的過程,就變成了極小化上述目標函數的過程。而且,上面我們也提到過,E的函數性質很有特點,其在階段域內整點處處不可導,其余點處處導數為0。利用該階段函數的特點,我們很容對上述目標函數進行求解。不過先別急,我們先算算上面目標函數相對于s的導數,如下:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
由于E函數的特點,我們知道E關于s的隱函數求導,在可導處的函數值處處為0,不可導的地方我們先忽略掉(工程實現只需要避免奇點即可)。我們可以將上述式子再進一步簡化如下:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
再得到目標函數關于s的導數公式以后,后續的求解就非常簡單了,我們給出幾種求解思路。
第一種,比較直接的方法,可以采用梯度下降法,流程簡單如下:
## 給定s的初值(可以統計數據的min、max進行強映射初始化) 1、 計算L關于s的梯度 2、 更新s的值 3、 判斷終止條件 4、 更新步長第二種,稍微麻煩點,由第一種算法我們可以知道,步長的選取是啟發性的,收斂性雖然基本都能 得到保證,但如果我們想自適應的計算最優步長,可以采用golden-search的求解方式,求解過程跟梯度下降法很接近,如下:
## 給定s的初值(可以統計數據的min、max進行強映射初始化) 1、 計算L關于s的梯度 2、 計算最優步長 3、 更新s的值 4、 判斷終止條件第三種方式,類似一階收斂(梯度下降、golden-search)和二階收斂(牛頓等)我就不再拓展了;我們這里介紹一下MNN中采用的優化算法,ADMM(交替方向法)。在得到L關于s的梯度公式后,如下圖:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
其進行了一個概念轉化,求解L的極小化問題,轉化成求解L關于s的導數為0的數學問題。但由于函數E中還包含了未知數s,MNN在求解這個數據問題的時候,采用了ADMM的思想,先freezeE中的s變量,進行s值的求解;然后再通過求解后的s去估計E的值,如此反復直到s值收斂到最優。于是上述的式子要弱化為如下公式:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
所以迭代的過程就變成了交替迭代求解如下兩個公式的過程,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
ADMM的迭代過程
于是算法流程可構造如下:
## 給定s的初值(可以統計數據的min、max進行強映射初始化) 1、 計算E的值 2、 更新s的值 3、 判斷終止條件關于三種算法的終止條件判斷,可以采用l2 loss的方式即||L(k+1) - L(L(k)) || / N。也可以采用相對變化率的方式,即當前迭代變化和上次迭代變化的比值。不過其實不判斷問題也不大,如MNN中的源碼實現中是按照最大固定迭代次數來實現的。
總結
以上是生活随笔為你收集整理的MNN量化—ADMM优化算法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pythonfor循环100次_以写代学
- 下一篇: Lattice FPGA 开发工具Dia