SuperPoint:深度学习特征点+描述子
【原文鏈接】:https://www.vincentqin.tech/posts/superpoint/
本文出自近幾年備受矚目的創業公司MagicLeap,發表在CVPR 2018,一作Daniel DeTone,[paper],[slides],[code]。
這篇文章設計了一種自監督網絡框架,能夠同時提取特征點的位置以及描述子。相比于patch-based方法,本文提出的算法能夠在原始圖像提取到像素級精度的特征點的位置及其描述子。
本文提出了一種單應性適應(Homographic Adaptation)的策略以增強特征點的復檢率以及跨域的實用性(這里跨域指的是synthetic-to-real的能力,網絡模型在虛擬數據集上訓練完成,同樣也可以在真實場景下表現優異的能力)。
介紹
諸多應用(諸如SLAM/SfM/相機標定/立體匹配)的首要一步就是特征點提取,這里的特征點指的是能夠在不同光照&不同視角下都能夠穩定且可重復檢測的2D圖像點位置。
基于CNN的算法幾乎在以圖像作為輸入的所有領域表現出相比于人類特征工程更加優秀的表達能力。目前已經有一些工作做類似的任務,例如人體位姿估計,目標檢測以及室內布局估計等。這些算法以通常以大量的人工標注作為GT,這些精心設計的網絡用來訓練以得到人體上的角點,例如嘴唇的邊緣點亦或人體的關節點,但是這里的問題是這里的點實際是ill-defined(我的理解是,這些點有可能是特征點,但僅僅是一個大概的位置,是特征點的子集,并沒有真正的把特征點的概念定義清楚)。
本文采用了非人工監督的方法提取真實場景的特征點。本文設計了一個由特征點檢測器監督的具有偽真值數據集,而非是大量的人工標記。為了得到偽真值,本文首先在大量的虛擬數據集上訓練了一個全卷積網絡(FCNN),這些虛擬數據集由一些基本圖形組成,例如有線段、三角形、矩形和立方體等,這些基本圖形具有沒有爭議的特征點位置,文中稱這些特征點為MagicPoint,這個pre-trained的檢測器就是MagicPoint檢測器。這些MagicPoint在虛擬場景的中檢測特征點的性能明顯優于傳統方式,但是在真實的復雜場景中表現不佳,此時作者提出了一種多尺度多變換的方法Homographic Adaptation。對于輸入圖像而言,Homographic Adaptation通過對圖像進行多次不同的尺度/角度變換來幫助網絡能夠在不同視角不同尺度觀測到特征點。
綜上:SuperPoint = MagicPoint+Homographic Adaptation
算法優劣對比
- 基于圖像塊的算法導致特征點位置精度不夠準確;
- 特征點與描述子分開進行訓練導致運算資源的浪費,網絡不夠精簡,實時性不足;或者僅僅訓練特征點或者描述子的一種,不能用同一個網絡進行聯合訓練;
網絡結構
上圖可見特征點檢測器以及描述子網絡共享一個單一的前向encoder,只是在decoder時采用了不同的結構,根據任務的不同學習不同的網絡參數。這也是本框架與其他網絡的不同之處:其他網絡采用的是先訓練好特征點檢測網絡,然后再去進行對特征點描述網絡進行訓練。
網絡共分成以下4個主要部分,在此進行詳述:
1. Shared Encoder 共享的編碼網絡
從上圖可以看到,整體而言,本質上有兩個網絡,只是前半部分共享了一部分而已。本文利用了VGG-style的encoder以用于降低圖像尺寸,encoder包括卷積層,max-pooling層,以及非線性激活層。通過3個max-pooling層將圖像的尺寸變成Hc=H/8H_c = H/8Hc?=H/8和Hc=H/8H_c = H/8Hc?=H/8,經過encoder之后,圖像由I∈RH×WI \in \mathcal{R}^{H \times W}I∈RH×W變為張量B∈RHc×Wc×F\mathcal{B} \in \mathbb{R}^{H_c \times W_c \times F}B∈RHc?×Wc?×F
2. Interest Point Decoder
這里介紹的是特征點的解碼端。每個像素的經過該解碼器的輸出是該像素是特征點的概率(probability of “point-ness”)。
通常而言,我們可以通過反卷積得到上采樣的圖像,但是這種操作會導致計算量的驟增以及會引入一種“checkerboard artifacts”。因此本文設計了一種帶有“特定解碼器”(這種解碼器沒有參數)的特征點檢測頭以減小模型計算量(子像素卷積)。
例如:輸入張量的維度是RHc×Wc×65\mathbb{R}^{H_c \times W_c \times 65}RHc?×Wc?×65,輸出維度RH×W\mathbb{R}^{H \times W}RH×W,即圖像的尺寸。這里的65表示原圖8×88 \times 88×8的局部區域,加上一個非特征點dustbin。通過在channel維度上做softmax,非特征點dustbin會被刪除,同時會做一步圖像的reshape:RHc×Wc×64?RH×W\mathbb{R}^{H_c \times W_c \times 64} \Rightarrow \mathbb{R}^{H \times W}RHc?×Wc?×64?RH×W 。(這就是**子像素卷積**的意思,俗稱像素洗牌)
拋出特征點解碼端部分代碼:
# Compute the dense keypoint scores cPa = self.relu(self.convPa(x)) scores = self.convPb(cPa) # DIM: N x 65 x H/8 x W/8 scores = torch.nn.functional.softmax(scores, 1)[:, :-1] # DIM: N x 64 x H/8 x W/8 b, _, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) # DIM: N x H/8 x W/8 x 8 x 8 scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) # DIM: N x H x W這個過程看似比較繁瑣,但是這其實就是一個由depth to space的過程,以N = 1為例,上述過程如下圖所示:
上圖中所示的3個藍色小塊的就是對應的一個cell經過depth to space后得到的,易知其尺寸是8×88 \times 88×8。
注意 :這里解釋一下為何此作者設置選擇增加一個dustbin通道,以及為何先進行softmax再進行slice操作,先進行slice再進行softmax是否可行?(scores = torch.nn.functional.softmax(scores, 1)[:, :-1])
之所以要設置65個通道,這是因為算法要應對不存在特征點的情況。注意到之后的一步中使用了softmax,也就是說沿著通道維度把各個數值通過運算后加和為1。如果沒有Dustbin通道,這里就會產生一個問題:若該cell處沒有特征點,此時經過softmax后,每個通道上的響應就會出現受到噪聲干擾造成異常隨機,在隨后的特征點選擇一步中會將非特征點判定為特征,這個過程由下圖左圖所示。在添加Dustbin之后,在沒有特征的情況下,只有在Dustbin通道的響應值很大,在后續的特征點判斷階段,此時該圖像塊的響應都很小,會成功判定為無特征點,這個過程由下圖右圖所示。
上述過程中得到的scores就是圖像上特征點的概率(或者叫做特征響應,后文中響應值即表示概率值),概率越大,該點越有可能是特征點。之后作者進行了一步nms,即非極大值抑制(simple_nms的實現見文末),隨后選擇響應值較大的位置作為特征點。
scores = simple_nms(scores, self.config['nms_radius']) keypoints = [ torch.nonzero(s > self.config['keypoint_threshold']) for s in scores]nms的效果如下,左圖是未使用nms時score的樣子,響應值極大的位置周圍也聚集著響應較大的點,如果不進行nms,特征點將會很集中;右圖是進行nms操作后的score,響應值極大的位置周圍的響應為0。
nms前后對應的特征點的位置如下所示,可見nms對于避免特征點位置過于集中起到了比較大的作用。
熟悉SuperPoint的同學應該注意到了,Daniel在CVPR 2018公開的實現中nms在特征點提取之后,而Sarlin于CVPR 2020年公開SuperGlue的同時對SuperPoint進行了重構,后者在score上進行nms,這兩種實現上存在一些差異。
下面給出的是Daniel在CVPR 2018開源的SuperPoint推理代碼節選。
nodust = nodust.transpose(1, 2, 0) heatmap = np.reshape(nodust, [Hc, Wc, self.cell, self.cell]) heatmap = np.transpose(heatmap, [0, 2, 1, 3]) heatmap = np.reshape(heatmap, [Hc*self.cell, Wc*self.cell]) xs, ys = np.where(heatmap >= self.conf_thresh) # Confidence threshold. if len(xs) == 0:return np.zeros((3, 0)), None, None pts = np.zeros((3, len(xs))) # Populate point data sized 3xN. pts[0, :] = ys pts[1, :] = xs pts[2, :] = heatmap[xs, ys] pts, _ = self.nms_fast(pts, H, W, dist_thresh=self.nms_dist) # Apply NMS.但Sarlin為何要這么做呢?本人在Github上提交了一個#issue112咨詢了Sarlin,如下是他的回復,總結起來就重構后的代碼優勢有兩點:1. 更加快速,能夠在GPU上運行,常數級時間復雜度;2. 支持多圖像輸入。
3. Descriptor Decoder
首先利用類似于UCN的網絡得到一個半稠密的描述子(此處參考文獻UCN),這樣可以減少算法訓練內存開銷同時減少算法運行時間。之后通過雙三次多項式插值得到其余描述,然后通過L2-normalizes歸一化描述子得到統一的長度描述。特征維度由D∈RHc×Wc×D\mathcal{D} \in \mathbb{R}^{H_c \times W_c \times D}D∈RHc?×Wc?×D變為RH×W×D\mathbb{R}^{H\times W \times D}RH×W×D 。
由特征點得到其描述子的過程文中沒有細講,看了一下源代碼就明白了。其實該過程主要用了一個函數即grid_sample,畫了一個草圖作為解釋。
- 圖像尺寸歸一化:首先對圖像的尺寸進行歸一化,(-1,-1)表示原來圖像的(0,0)位置,(1,1)表示原來圖像的(H-1,W-1)位置,這樣一來,特征點的位置也被歸一化到了相應的位置。
- 構建grid:將歸一化后的特征點羅列起來,構成一個尺度為1*1*K*2的張量,其中K表示特征數量,2分別表示xy坐標。
- 特征點位置反歸一化:根據輸入張量的H與W對grid(1,1,0,:)(表示第一個特征點,其余特征點類似)進行反歸一化,其實就是按照比例進行縮放+平移,得到反歸一化特征點在張量某個slice(通道)上的位置;但是這個位置可能并非為整像素,此時要對其進行雙線性插值補齊,然后其余slice按照同樣的方式進行雙線性插值。注:代碼中實際的就是雙線性插值,并非文中講的雙三次插值;
- 輸出維度:1*C*1*K。
描述子解碼部分代碼如下:
# Compute the dense descriptors cDa = self.relu(self.convDa(x)) descriptors = self.convDb(cDa) # DIM: N x 256 x H/8 x W/8 descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) #按通道進行歸一化# Extract descriptors # 根據特征點位置插值得到描述子, DIM: N x 256 x Mdescriptors = [sample_descriptors(k[None], d[None], 8)[0]for k, d in zip(keypoints, descriptors)]4. 誤差構建
L(X,X′,D,D′;Y,Y′,S)=Lp(X,Y)+Lp(X′,Y′)+λLd(D,D′,S)\begin{array}{l}{\mathcal{L}\left(\mathcal{X}, \mathcal{X}^{\prime}, \mathcal{D}, \mathcal{D}^{\prime} ; Y, Y^{\prime}, S\right)=} \\ {\qquad \mathcal{L}_{p}(\mathcal{X}, Y)+\mathcal{L}_{p}\left(\mathcal{X}^{\prime}, Y^{\prime}\right)+\lambda \mathcal{L}_ozvdkddzhkzd\left(\mathcal{D}, \mathcal{D}^{\prime}, S\right)}\end{array} L(X,X′,D,D′;Y,Y′,S)=Lp?(X,Y)+Lp?(X′,Y′)+λLd?(D,D′,S)?
可見損失函數由兩項組成,其中一項為特征點檢測lossLp\mathcal{L}_{p}Lp? ,另外一項是描述子的lossLd\mathcal{L}_ozvdkddzhkzdLd?。
對于檢測項loss,此時采用了交叉熵損失函數:
Lp(X,Y)=1HcWc∑h=1w=1Hc,Wclp(xhw;yhw)\mathcal{L}_{p}(\mathcal{X}, Y)=\frac{1}{H_{c} W_{c}} \sum_{h=1 \atop w=1}^{H_{c}, W_{c}} l_{p}\left(\mathbf{x}_{h w} ; y_{h w}\right) Lp?(X,Y)=Hc?Wc?1?w=1h=1?∑Hc?,Wc??lp?(xhw?;yhw?)
其中:
lp(xhw;y)=?log?(exp?(xhwy)∑k=165exp?(xhwk))l_{p}\left(\mathbf{x}_{h w} ; y\right)=-\log \left(\frac{\exp \left(\mathbf{x}_{h w y}\right)}{\sum_{k=1}^{65} \exp \left(\mathbf{x}_{h w k}\right)}\right) lp?(xhw?;y)=?log(∑k=165?exp(xhwk?)exp(xhwy?)?)
此時類似于一個多分類任務,log?\loglog 運算內部就是cell中元素為特征點的概率(即softmax之后的值),即樣本xhw\mathbf{x}_{hw}xhw?屬于特征的概率。這是一個2D location classifier,每個8x8的范圍內只能有一個特征點,即圖像中最多有$H \times W / 64 $個SuperPoint特征點。
描述子的損失函數:
Ld(D,D′,S)=1(HcWc)2∑h=1w=1Hc,Wc∑h′=1w′=1Hc,Wcld(dhw,dh′w′′;shwh′w′)\mathcal{L}_ozvdkddzhkzd\left(\mathcal{D}, \mathcal{D}^{\prime}, S\right)=\frac{1}{\left(H_{c} W_{c}\right)^{2}} \sum_{h=1 \atop w=1}^{H_{c}, W_{c}} \sum_{h^{\prime}=1 \atop w^{\prime}=1}^{H_{c}, W_{c}} l_ozvdkddzhkzd\left(\mathbfozvdkddzhkzd_{h w}, \mathbfozvdkddzhkzd_{h^{\prime} w^{\prime}}^{\prime} ; s_{h w h^{\prime} w^{\prime}}\right) Ld?(D,D′,S)=(Hc?Wc?)21?w=1h=1?∑Hc?,Wc??w′=1h′=1?∑Hc?,Wc??ld?(dhw?,dh′w′′?;shwh′w′?)
其中ldl_ozvdkddzhkzdld?為Hinge-loss(合頁損失函數,用于SVM,如支持向量的軟間隔,可以保證最后解的稀疏性);
ld(d,d′;s)=λd?s?max?(0,mp?dTd′)+(1?s)?max?(0,dTd′?mn)l_ozvdkddzhkzd\left(\mathbfozvdkddzhkzd, \mathbfozvdkddzhkzd^{\prime} ; s\right)=\lambda_ozvdkddzhkzd * s * \max \left(0, m_{p}-\mathbfozvdkddzhkzd^{T} \mathbfozvdkddzhkzd^{\prime}\right)+(1-s) * \max \left(0, \mathbfozvdkddzhkzd^{T} \mathbfozvdkddzhkzd^{\prime}-m_{n}\right) ld?(d,d′;s)=λd??s?max(0,mp??dTd′)+(1?s)?max(0,dTd′?mn?)
同時指示函數為shwh′w′s_{h w h^{\prime} w^{\prime}}shwh′w′?,SSS表示所有正確匹配對集合:
shwh′w′={1,if?∥Hphw^?ph′w′∥≤80,otherwise?s_{h w h^{\prime} w^{\prime}}=\left\{\begin{array}{ll}{1,} & {\text { if }\left\|\widehat{\mathcal{H} \mathbf{p}_{h w}}-\mathbf{p}_{h^{\prime} w^{\prime}}\right\| \leq 8} \\ {0,} & {\text { otherwise }}\end{array}\right. shwh′w′?={1,0,??if?∥∥∥?Hphw???ph′w′?∥∥∥?≤8?otherwise??
上式中的p\mathbf{p}p是cell的中心點坐標,Hp\mathcal{H} \mathbf{p}Hp與p′\mathbf{p}^{\prime}p′的距離小于8個pixel的認為是正確的匹配,這其實對應于cell上的的1個pixel。
讓我們仔細看一下這個損失函數,這其實是一個Double margin Siamese loss。當正例描述子余弦相似度dTd′\mathbfozvdkddzhkzd^T\mathbfozvdkddzhkzd^{\prime}dTd′大于mpm_pmp?時,此時不需要懲罰;但如果該相似度較小時,此時就要懲罰了;負樣本時我們的目標是讓dTd′\mathbfozvdkddzhkzd^T\mathbfozvdkddzhkzd^{\prime}dTd′變小,但網絡性能不佳時可能這個值很大(大于上式中的mnm_nmn?),此時要懲罰這種現象,網絡權重經過調整后使得該loss降低,對應的描述子相似度降低;
讓我們再看一下這個所謂的Double margin Siamese loss,上圖示中的連線表示distdistdist函數。想象一下,我們希望正例𝑑𝑖𝑠𝑡(𝑑,𝑑′)𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})dist(d,d′)越小越好,如果𝑑𝑖𝑠𝑡(𝑑,𝑑′)>𝑚𝑝1𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})>𝑚_{𝑝1}dist(d,d′)>mp1?,網絡要懲罰這種現象,會使得𝑑𝑖𝑠𝑡(𝑑,𝑑′)<𝑚𝑝1𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})<𝑚_{𝑝1}dist(d,d′)<mp1?.相應的的我們希望負例𝑑𝑖𝑠𝑡(𝑑,𝑑′)𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})dist(d,d′)越大越好,如果𝑑𝑖𝑠𝑡(𝑑,𝑑′)<𝑚𝑛1𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})<𝑚_{𝑛1}dist(d,d′)<mn1?,網絡要懲罰這種現象,最終會使得𝑑𝑖𝑠𝑡(𝑑,𝑑′)>𝑚𝑛1𝑑𝑖𝑠𝑡(𝑑,𝑑^{\prime})>𝑚_{𝑛1}dist(d,d′)>mn1?。
網絡訓練
本文一共設計了兩個網絡,一個是BaseDetector,用于檢測角點(注意,此處提取的并不是最終輸出的特征點,可以理解為候選的特征點),另一個是SuperPoint網絡,輸出特征點和描述子。
網絡的訓練共分為三個步驟:
這里需要注意的是,聯合訓練使用的單應變換相較于Homographic Adaptation中設置的單應變換更加嚴格,即沒有特別離譜的in-plane的旋轉。作者在論文中提到,這是由于在HPatches數據集中沒有這樣的數據才進行這種設置,原話是“we avoid sampling extreme in-plane rotations as they are rarely seen in HPatches”,這也是為什么SuperPoint無法有效地應對in-plane rotations的原因。
預訓練Magic Point
此處參考作者之前發表的一篇論文**[Toward Geometric Deep SLAM]**,其實就是MagicPoint,它僅僅保留了SuperPoint的主干網絡以及特征點解碼端,即SuperPoint的檢測端就是MagicPoint。
Homographic Adaptation
算法在虛擬數據集上表現極其優秀,但是在真實場景下表示沒有達到預期,此時本文進行了Homographic Adaptation。
作者使用的數據集是MS-COCO,為了使網絡的泛化能力更強,本文不僅使用原始了原始圖片,而且對每張圖片進行隨機的旋轉和縮放形成新的圖片,新的圖片也被用來進行識別。這一步其實就類似于訓練里常用的數據增強。經過一系列的單應變換之后特征點的復檢率以及普適性得以增強。值得注意的是,在實際訓練時,這里采用了迭代使用單應變換的方式,例如使用優化后的特征點檢測器重新進行單應變換進行訓練,然后又可以得到更新后的檢測器,如此迭代優化,這就是所謂的self-supervisd。
最后的關鍵點檢測器,即F^(I;fθ)\hat{F}\left(I ; f_{\theta}\right)F^(I;fθ?),可以表示為再所有隨機單應變換/反變換的聚合:
F^(I;fθ)=1Nh∑i=1NhHi?1fθ(Hi(I))\hat{F}\left(I ; f_{\theta}\right)=\frac{1}{N_{h}} \sum_{i=1}^{N_{h}} \mathcal{H}_{i}^{-1} f_{\theta}\left(\mathcal{H}_{i}(I)\right) F^(I;fθ?)=Nh?1?i=1∑Nh??Hi?1?fθ?(Hi?(I))
構建殘差,迭代優化描述子以及檢測器
利用上面網絡得到的關鍵點位置以及描述子表示構建殘差,利用ADAM進行優化。
實驗結果
總結
未來工作:
作者最后提到,他相信該網絡能夠解決SLAM或者SfM領域的數據關聯*,并且*learning-based前端可以使得諸如機器人或者AR等應用獲得更加魯棒。
代碼
以下給出的是Sarlin在SuperGlue代碼中重構的SuperPoint前向推理代碼,與Daniel于2018年的原始版本有些差異。不過Sarlin的版本與原版結果幾乎一致,另外增加多batch的支持,執行效率更高。
# %BANNER_BEGIN% # --------------------------------------------------------------------- # %COPYRIGHT_BEGIN% # # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL # # Unpublished Copyright (c) 2020 # Magic Leap, Inc., All Rights Reserved. # # NOTICE: All information contained herein is, and remains the property # of COMPANY. The intellectual and technical concepts contained herein # are proprietary to COMPANY and may be covered by U.S. and Foreign # Patents, patents in process, and are protected by trade secret or # copyright law. Dissemination of this information or reproduction of # this material is strictly forbidden unless prior written permission is # obtained from COMPANY. Access to the source code contained herein is # hereby forbidden to anyone except current COMPANY employees, managers # or contractors who have executed Confidentiality and Non-disclosure # agreements explicitly covering such access. # # The copyright notice above does not evidence any actual or intended # publication or disclosure of this source code, which includes # information that is confidential and/or proprietary, and is a trade # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. # # %COPYRIGHT_END% # ---------------------------------------------------------------------- # %AUTHORS_BEGIN% # # Originating Authors: Paul-Edouard Sarlin # # %AUTHORS_END% # --------------------------------------------------------------------*/ # %BANNER_END%from pathlib import Path import torch from torch import nndef simple_nms(scores, nms_radius: int):""" Fast Non-maximum suppression to remove nearby points """assert(nms_radius >= 0)def max_pool(x):return torch.nn.functional.max_pool2d(x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)zeros = torch.zeros_like(scores)max_mask = scores == max_pool(scores)for _ in range(2):supp_mask = max_pool(max_mask.float()) > 0supp_scores = torch.where(supp_mask, zeros, scores)new_max_mask = supp_scores == max_pool(supp_scores)max_mask = max_mask | (new_max_mask & (~supp_mask))return torch.where(max_mask, scores, zeros)def remove_borders(keypoints, scores, border: int, height: int, width: int):""" Removes keypoints too close to the border """mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))mask = mask_h & mask_wreturn keypoints[mask], scores[mask]def top_k_keypoints(keypoints, scores, k: int):if k >= len(keypoints):return keypoints, scoresscores, indices = torch.topk(scores, k, dim=0)return keypoints[indices], scoresdef sample_descriptors(keypoints, descriptors, s: int = 8):""" Interpolate descriptors at keypoint locations """b, c, h, w = descriptors.shapekeypoints = keypoints - s / 2 + 0.5keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],).to(keypoints)[None]keypoints = keypoints*2 - 1 # normalize to (-1, 1)args = {'align_corners': True} if torch.__version__ >= '1.3' else {}descriptors = torch.nn.functional.grid_sample(descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)descriptors = torch.nn.functional.normalize(descriptors.reshape(b, c, -1), p=2, dim=1)return descriptorsclass SuperPoint(nn.Module):"""SuperPoint Convolutional Detector and DescriptorSuperPoint: Self-Supervised Interest Point Detection andDescription. Daniel DeTone, Tomasz Malisiewicz, and AndrewRabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629"""default_config = {'descriptor_dim': 256,'nms_radius': 4,'keypoint_threshold': 0.005,'max_keypoints': -1,'remove_borders': 4,}def __init__(self, config):super().__init__()self.config = {**self.default_config, **config}self.relu = nn.ReLU(inplace=True)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)self.convDb = nn.Conv2d(c5, self.config['descriptor_dim'],kernel_size=1, stride=1, padding=0)path = Path(__file__).parent / 'weights/superpoint_v1.pth'self.load_state_dict(torch.load(str(path)))mk = self.config['max_keypoints']if mk == 0 or mk < -1:raise ValueError('\"max_keypoints\" must be positive or \"-1\"')print('Loaded SuperPoint model')def forward(self, data):""" Compute keypoints, scores, descriptors for image """# Shared Encoderx = self.relu(self.conv1a(data['image']))x = self.relu(self.conv1b(x))x = self.pool(x)x = self.relu(self.conv2a(x))x = self.relu(self.conv2b(x))x = self.pool(x)x = self.relu(self.conv3a(x))x = self.relu(self.conv3b(x))x = self.pool(x)x = self.relu(self.conv4a(x))x = self.relu(self.conv4b(x))# Compute the dense keypoint scorescPa = self.relu(self.convPa(x))scores = self.convPb(cPa)scores = torch.nn.functional.softmax(scores, 1)[:, :-1]b, _, h, w = scores.shapescores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)scores = simple_nms(scores, self.config['nms_radius'])# Extract keypointskeypoints = [torch.nonzero(s > self.config['keypoint_threshold'])for s in scores]scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]# Discard keypoints near the image borderskeypoints, scores = list(zip(*[remove_borders(k, s, self.config['remove_borders'], h*8, w*8)for k, s in zip(keypoints, scores)]))# Keep the k keypoints with highest scoreif self.config['max_keypoints'] >= 0:keypoints, scores = list(zip(*[top_k_keypoints(k, s, self.config['max_keypoints'])for k, s in zip(keypoints, scores)]))# Convert (h, w) to (x, y)keypoints = [torch.flip(k, [1]).float() for k in keypoints]# Compute the dense descriptorscDa = self.relu(self.convDa(x))descriptors = self.convDb(cDa)descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)# Extract descriptorsdescriptors = [sample_descriptors(k[None], d[None], 8)[0]for k, d in zip(keypoints, descriptors)]return {'keypoints': keypoints,'scores': scores,'descriptors': descriptors,}歡迎大家關注我的公眾號,最新文章第一時間推送。
總結
以上是生活随笔為你收集整理的SuperPoint:深度学习特征点+描述子的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JAVA算法:哈希查找
- 下一篇: 人脸搜索引擎准得吓人,记者:我都不知道自