Implementing a Convolution Layer with Meta-operators
Meta-operators are a key concept in Jittor; their hierarchy is shown below.
Meta-operators are composed of reindex operators, reindex-reduce operators, and element-wise operators. Reindex and reindex-reduce operators are both unary. A reindex operator is a one-to-many mapping between its input and output, while a reindex-reduce operator is a many-to-one mapping. Broadcast, pad, and slice are common reindex operators, and reduce operators such as sum and product are common reindex-reduce operators. Element-wise operators are the third part of meta-operators; unlike the first two, an element-wise operator may take multiple inputs. All inputs and outputs of an element-wise operator must have the same shape, and they are mapped one-to-one. For example, the addition of two variables is a binary element-wise operator.
The hierarchy of meta-operators: meta-operators comprise three classes of operators (reindex, reindex-reduce, and element-wise). The backward operators of meta-operators are themselves meta-operators. Meta-operators can be composed into common deep learning operators, and these deep learning operators can in turn be composed into deep learning models.
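As a concrete, standalone illustration of the three classes (a minimal sketch; the shapes and values here are made up purely for demonstration), the snippet below applies a reindex operator (broadcast, one-to-many), a reindex-reduce operator (sum, many-to-one), and a binary element-wise operator (add, one-to-one):

import jittor as jt
import numpy as np
a = jt.array(np.arange(6, dtype="float32").reshape(2, 3))
# reindex (one-to-many): broadcast [2,3] -> [2,3,4] by inserting a new axis 2
b = a.broadcast([2, 3, 4], dims=[2])
# reindex-reduce (many-to-one): sum over the broadcast axis, back to [2,3]
c = b.sum(dim=2)
# element-wise (one-to-one): both inputs and the output share the shape [2,3]
d = a + c
print(b.shape, c.shape, d.shape)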
Earlier we showed how matrix multiplication can be implemented with three meta-operators:
def matmul(a, b):
    (n, m), k = a.shape, b.shape[-1]
    a = a.broadcast([n,m,k], dims=[2])
    b = b.broadcast([n,m,k], dims=[0])
    return (a*b).sum(dim=1)
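As a quick sanity check (a sketch; the shapes and random inputs are arbitrary), this meta-operator matmul can be compared against NumPy's reference result:

import numpy as np
import jittor as jt
na = np.random.rand(3, 4).astype("float32")
nb = np.random.rand(4, 5).astype("float32")
jc = matmul(jt.array(na), jt.array(nb)).fetch_sync()
assert np.allclose(jc, na @ nb, atol=1e-5)  # matches numpy matmul within float32 tolerance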
This article shows how to implement convolution using meta-operators.
First, implement a naive convolution in Python:
import numpy as np
import os
def conv_naive(x, w):
    N,H,W,C = x.shape
    Kh, Kw, _C, Kc = w.shape
    assert C==_C, (x.shape, w.shape)
    y = np.zeros([N,H-Kh+1,W-Kw+1,Kc])
    for i0 in range(N):
        for i1 in range(H-Kh+1):
            for i2 in range(W-Kw+1):
                for i3 in range(Kh):
                    for i4 in range(Kw):
                        for i5 in range(C):
                            for i6 in range(Kc):
                                if i1-i3<0 or i2-i4<0 or i1-i3>=H or i2-i4>=W: continue
                                y[i0, i1, i2, i6] += x[i0, i1 + i3, i2 + i4, i5] * w[i3,i4,i5,i6]
    return y
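Before applying it to a real image, a quick shape check (a minimal sketch with arbitrary sizes): a 5x5 single-channel input convolved with a 3x3 kernel should give a 3x3 output.

xs = np.random.rand(1, 5, 5, 1).astype("float32")
ws = np.random.rand(3, 3, 1, 1).astype("float32")
print(conv_naive(xs, ws).shape)  # expected: (1, 3, 3, 1)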
Download an image of a cat and use conv_naive to apply a simple horizontal edge filter.
%matplotlib inline
import pylab as pl
img_path="/tmp/cat.jpg"
if not os.path.isfile(img_path):
    !wget -O - 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg/220px-Felis_silvestris_catus_lying_on_rice_straw.jpg' > $img_path
img = pl.imread(img_path)
pl.subplot(121)
pl.imshow(img)
kernel = np.array([
[-1, -1, -1],
[0, 0, 0],
[1, 1, 1],
])
pl.subplot(122)
x = img[np.newaxis,:,:,:1].astype("float32")
w = kernel[:,:,np.newaxis,np.newaxis].astype("float32")
y = conv_naive(x, w)
print (x.shape, y.shape)  # note: the output is slightly smaller than the input (no padding)
pl.imshow(y[0,:,:,0])
conv_naive works as expected. Now replace the naive implementation with Jittor.
import jittor as jt

def conv(x, w):
    N,H,W,C = x.shape
    Kh, Kw, _C, Kc = w.shape
    assert C==_C
    xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
        'i0',    # Nid
        'i1+i3', # Hid+Khid
        'i2+i4', # Wid+Kwid
        'i5',    # Cid
    ])
    ww = w.broadcast_var(xx)
    yy = xx*ww
    y = yy.sum([3,4,5]) # reduce over Kh, Kw, C
    return y
Let's disable the tuner first. This prevents Jittor from using MKL for the convolution:
jt.flags.enable_tuner = 0
jx = jt.array(x)
jw = jt.array(w)
jy = conv(jx, jw).fetch_sync()
print (jx.shape, jy.shape)
pl.imshow(jy[0,:,:,0])
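To confirm that the meta-operator convolution matches the naive NumPy version numerically (a short sketch reusing the y and jy arrays computed above):

assert np.allclose(y, jy, atol=1e-4)  # conv and conv_naive agree within float32 tolerance
print("results match")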
The result looks the same. How about the performance?
%time y = conv_naive(x, w)
%time jy = conv(jx, jw).fetch_sync()
As you can see, the Jittor implementation is much faster. Why does the Jittor implementation run so much faster when the two are mathematically equivalent? Let's explain step by step:
First, take a look at the help documentation of jt.reindex.
help(jt.reindex)
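As a small standalone illustration of reindex (a sketch; the shape and index expressions are chosen only for demonstration), the following embeds a 2x2 array into a 4x4 array by shifting the indices; positions whose computed input index falls out of range are treated as overflow and, as in the expansion below, filled with 0:

a = jt.array(np.array([[1., 2.], [3., 4.]], dtype="float32"))
# output position (i0, i1) reads input position (i0-1, i1-1); out-of-range reads become 0
p = a.reindex([4, 4], ["i0-1", "i1-1"])
print(p.fetch_sync())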
The reindex operation can be expanded to make it easier to understand:
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
    'i0',    # Nid
    'i1+i3', # Hid+Khid
    'i2+i4', # Wid+Kwid
    'i5',    # Cid
])
ww = w.broadcast_var(xx)
yy = xx*ww
y = yy.sum([3,4,5]) # reduce over Kh, Kw, C
After expansion:
shape = [N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc]

# expansion of x.reindex
xx = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
    for i1 in range(shape[1]):
        for i2 in range(shape[2]):
            for i3 in range(shape[3]):
                for i4 in range(shape[4]):
                    for i5 in range(shape[5]):
                        for i6 in range(shape[6]):
                            if is_overflow(i0,i1,i2,i3,i4,i5,i6):
                                xx[i0,i1,i2,i3,i4,i5,i6] = 0
                            else:
                                xx[i0,i1,i2,i3,i4,i5,i6] = x[i0,i1+i3,i2+i4,i5]
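The pseudo-code above relies on an is_overflow helper that is never defined; conceptually it checks whether any computed input index falls outside the bounds of x. A hypothetical definition for this particular reindex (written only to make the expansion concrete) could be:

def is_overflow(i0, i1, i2, i3, i4, i5, i6):
    # the reindex reads x[i0, i1+i3, i2+i4, i5], so overflow means one of those
    # indices leaves x's shape [N, H, W, C]; i0 and i5 always stay in range here
    return (i1 + i3 < 0 or i1 + i3 >= H or
            i2 + i4 < 0 or i2 + i4 >= W)

Note that with the output shape [N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc] these spatial indices can never actually overflow; the guard matters for variants such as padded convolution.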
# expansion of w.broadcast_var(xx)
ww = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
    for i1 in range(shape[1]):
        for i2 in range(shape[2]):
            for i3 in range(shape[3]):
                for i4 in range(shape[4]):
                    for i5 in range(shape[5]):
                        for i6 in range(shape[6]):
                            ww[i0,i1,i2,i3,i4,i5,i6] = w[i3,i4,i5,i6]
# expansion of xx*ww
yy = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
    for i1 in range(shape[1]):
        for i2 in range(shape[2]):
            for i3 in range(shape[3]):
                for i4 in range(shape[4]):
                    for i5 in range(shape[5]):
                        for i6 in range(shape[6]):
                            yy[i0,i1,i2,i3,i4,i5,i6] = xx[i0,i1,i2,i3,i4,i5,i6] * ww[i0,i1,i2,i3,i4,i5,i6]
# expansion of yy.sum([3,4,5])
shape2 = [N,H-Kh+1,W-Kw+1,Kc]
y = np.zeros(shape2, x.dtype)
for i0 in range(shape[0]):
    for i1 in range(shape[1]):
        for i2 in range(shape[2]):
            for i3 in range(shape[3]):
                for i4 in range(shape[4]):
                    for i5 in range(shape[5]):
                        for i6 in range(shape[6]):
                            y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6]
After loop fusion:
shape2 = [N,H-Kh+1,W-Kw+1,Kc]
y = np.zeros(shape2, x.dtype)
for i0 in range(shape[0]):
    for i1 in range(shape[1]):
        for i2 in range(shape[2]):
            for i3 in range(shape[3]):
                for i4 in range(shape[4]):
                    for i5 in range(shape[5]):
                        for i6 in range(shape[6]):
                            if not is_overflow(i0,i1,i2,i3,i4,i5,i6):
                                y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6]
This is the optimization trick behind meta-operators: multiple operators can be fused into a single, more complex fused operator, which covers many variations of convolution (for example group convolution and separable convolution); a depthwise variant is sketched below.
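For instance, a depthwise convolution (the per-channel half of a depthwise-separable convolution) can be written with the same reindex / broadcast / reduce pattern. The following is a minimal sketch, not Jittor's built-in implementation; the weight layout [Kh, Kw, C] and the explicit broadcast dims are assumptions made here for illustration:

def conv_depthwise(x, w):
    # x: [N, H, W, C], w: [Kh, Kw, C]; each channel is filtered independently
    N, H, W, C = x.shape
    Kh, Kw, _C = w.shape
    assert C == _C
    xx = x.reindex([N, H-Kh+1, W-Kw+1, Kh, Kw, C], [
        'i0',    # Nid
        'i1+i3', # Hid+Khid
        'i2+i4', # Wid+Kwid
        'i5',    # Cid
    ])
    # broadcast w over the N and output spatial dimensions (dims 0, 1, 2 are new)
    ww = w.broadcast([N, H-Kh+1, W-Kw+1, Kh, Kw, C], dims=[0, 1, 2])
    return (xx*ww).sum([3, 4])  # reduce over Kh, Kw only, keeping the channel dim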
Jittor tries to optimize the fused operator to be as fast as possible. Let's apply one more optimization (compiling the shapes into the kernel as constants) and look at the underlying C++ kernel code that is generated:
jt.flags.compile_options={"compile_shapes":1}
with jt.profile_scope() as report:
    jy = conv(jx, jw).fetch_sync()
jt.flags.compile_options={}

print(f"Time: {float(report[1][4])/1e6}ms")
with open(report[1][1], 'r') as f:
    print(f.read())
This is even faster than the previous run! In the output, take a look at the definition of func0, which is the main body of the convolution kernel; this kernel code is generated just in time. Because the compiler knows the shapes inside the kernel, more optimizations can be applied.
Summary

This article gave a simple demonstration of meta-operators in Jittor rather than a real performance benchmark, so a fairly small data size was used. For real performance testing, set jt.flags.enable_tuner = 1, which enables acceleration with specialized hardware libraries; a sketch of this is shown below.
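For reference, re-running the profiled convolution with the tuner enabled might look like the following (a sketch that simply combines the flags and profile_scope calls shown above; the timing depends on your machine and available backend libraries):

jt.flags.enable_tuner = 1  # allow Jittor to dispatch to tuned backends such as MKL
with jt.profile_scope() as report:
    jy = conv(jx, jw).fetch_sync()
print(f"Time: {float(report[1][4])/1e6}ms")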