Pytorch中tensor维度和torch.max()函数中dim参数的理解
Pytorch中tensor維度和torch.max()函數中dim參數的理解
維度
參考了 https://blog.csdn.net/qq_41375609/article/details/106078474 ,
對于torch中定義的張量,感覺上跟矩陣類似,不過常見的矩陣是二維的。當定義一個多維的張量時,比如使用 a =torch.randn(2, 3, 4) 創建一個三維的張量,返回的是一個
[[[-0.5166, 0.8298, 2.4580, -1.9504],[ 0.1119, -0.3321, -1.3478, -1.9198],[ 0.0522, -0.6053, 0.8119, -1.3469]],[[-0.3774, 0.9283, 0.7996, -0.3882],[-1.1077, 1.0664, 0.1263, -1.0631],[-0.9061, 1.0081, -1.2769, 0.1035]]
]
當使用 a.size() 返回維度結果時,結果為 torch.Size([2, 3, 4]),這里面有三個數值,數值的個數代表維度的個數 ,所以這里有三個維度,可以理解為一個中括號代表一個維度。數值 2 處在第一個位置,第一個位置代表是第一維度,2代表這個維度有兩個元素,也就是第一個 [ ] 里面兩個元素,3代表在第二個維度,也就是在第一個 [ ] 中的兩個元素里面,又有三個元素,依次類推。這里格式十分固定,一旦定義,必須是一個元素里面有兩個元素,這兩個元素中每個再包含三個元素,再包含,依此類推,否則會報錯。類似與樹,維數等于相似的樹的深度-1(以根為第一層),每一層就是一維。
如生成一個
torch.tensor([[[1, 2, 3, 4][3, 4, 2, 1][4, 1, 2, 3]][[2, 1, 3, 4][3, 4, 2, 1][4, 1, 2, 3]]]
)
方便理解,以下圖的形式展示,這里豎線代表一個維度,豎線上所有節點代表同一維度的所有元素。在下面所有圖中,同顏色的元素都是按照從上往下按順序排列的。
一、dim參數
在使用torch.max()函數和其他的一些函數時,會有dim這個參數。官網中定義使用torch.max()函數時,生成的張量維度會比原來的維度減少一維,除非原來的張量只有一維了. 要減少消去的是哪一維便是由dim參數決定的,dim參數實際上指的是我們計算過程中所要消去的維度。因為在比較時必須要指定使用哪些數字來比較 ,或者進行其他計算,比如 max 使一些數據中只要大的,sum只取和的結果,自然就會刪減其他的一些數據從而引起降維。
以上面生成的三維的張量為例子,有三個維度,但是維度的數字順序是 dim = 0, 1, 2;
當指定torch.max(a,dim=0)時,也就是要刪除第一個維度,刪除第一個維度的話,那還剩下兩個維度,也就是dim =1 ,2 。 剩下的兩個維度的參數是 3 和 4,那么刪除第一個維度后應該剩下torch.tensor(3, 4)這樣形式的張量, dim參數可以使用負數,也就是負的索引,與列表中的索引相似,在本例中dim = -1 與dim = 2是一樣的。
返回的
values=tensor([[-0.3774, 0.9283, 2.4580, -0.3882],[ 0.1119, 1.0664, 0.1263, -1.0631],[ 0.0522, 1.0081, 0.8119, 0.1035]]),
indices=tensor([[1, 1, 0, 1], [0, 1, 1, 1],[0, 1, 0, 1]]))
從返回的結果看是這種形式,產生這種結果是因為刪除了第一個維度那么該返回 3 * 4 這種二維的張量,第一維中兩個元素的形式正好是 3 * 4, 那么就將這個維度的兩個子元素中的相應的位置的值比較一下大小,那么會生成一個新的 3 * 4 的張量,再返回一下正好可以,indices記錄的是 "在比較中勝利的元素“ 原來所屬的元素的位置。例如在第一個位置上,-0.3774比 -0.5166大,所以返回-0.3774,-0.3774是在第一維度里面的第二個元素的位置上,這個位置索引為1.剩下的位置的同理。
用樹狀圖理解
圖中的不同顏色的三個子元素,在相同位置比較,大的返回形成新的元素,其他位置同理。那么黑色的維度 dim = 1 也就消除了.
dim = 0時,如圖,兩個3*4的子元素張量 相對應的位置 比較大小,剩下一個3 * 4的二維張量
當dim = 2或者 dim = -1,刪除的是最后一個維度,在這個例子中嗎,將所有的第三維的子元素最大的值返回,返回2 * 3,看起來就像是找所在矩陣一行里面的最大值一樣。
values=tensor([[2.4580, 0.1119, 0.8119],[0.9283, 1.0664, 1.0081]]),
indices=tensor([[2, 0, 2],[1, 1, 1]]))
舉一個sum()例子,當使用上述使用torch.sum(a,dim = 1),消去第二個維度,剩下一,三維度,也就是2 * 4形狀的張量。將第二維上面的三個子元素相同位置的相加,第二維也就不見了,第一維中的兩個元素的子元素就從3*4形成了一個1 *4的,總的形狀就變成了2 * 4
tensor([[-0.3525, -0.1076, 1.9221, -5.2171],[-2.3912, 3.0028, -0.3510, -1.3478]])
再舉一個例子,使用torch.randn(2, 3, 4, 5) 創建一個四維張量,使用torch.max(dim=-3),也就是torch.max(dim=1)
torch.tensor([[[[ 0.7106, 1.3332, -1.0423, -0.1609, -0.2846],[ 0.6400, 2.2507, -0.5740, -0.9986, 0.0066],[-0.0527, 1.4097, -0.4439, 0.4846, 1.5418],[ 1.0027, 0.9398, 1.5202, -1.1660, -0.1230]],[[ 0.5725, -1.7838, -0.7320, -1.4419, 1.5762],[ 0.6407, 0.0527, 1.7005, 1.6350, -0.2610],[ 1.3307, -0.3210, -1.7203, 0.9050, 0.2442],[ 0.9418, -0.1511, 0.8248, -0.0786, -0.6153]],[[ 1.0182, 0.3190, -0.3408, -2.1801, -0.3931],[ 1.2325, -0.3304, 1.0116, 0.0791, -1.1174],[ 0.2331, -0.9062, 0.5680, 1.6061, -1.0933],[ 0.6935, -0.5140, -0.5178, 1.2557, 0.2319]]],[[[ 1.0916, 0.7171, -0.7936, 1.1741, -0.5457],[-0.6541, -0.6720, -0.7892, -0.6961, -1.1030],[ 1.8680, -0.1746, 0.8455, -1.1021, 0.6855],[ 1.2070, -0.6152, -1.3345, -0.0724, 1.2062]],[[-0.5130, -0.5510, -0.8278, -0.2279, -1.4425],[ 0.2073, 1.3065, -0.0326, -1.2566, 0.6097],[-1.0413, 1.2638, -0.8479, -0.0353, -0.7191],[ 0.0662, 0.7683, 0.2145, -0.0988, -2.3348]],[[ 0.6631, -0.0040, -0.0681, 1.1681, 1.3904],[-0.1761, 1.4668, 0.9670, -0.5629, 0.2941],[-0.6235, 0.1844, -0.4321, -0.0581, -0.9352],[ 0.1717, -0.9188, 0.3014, -0.0734, -0.1324]]]])
在這里面,當dim = 1,也就是要動第二個維度手,那么刪掉它后剩下torch.randn(2,4, 5)形式,那么就
[[ 0.7106, 1.3332, -1.0423, -0.1609, -0.2846],
[ 0.6400, 2.2507, -0.5740, -0.9986, 0.0066],
[-0.0527, 1.4097, -0.4439, 0.4846, 1.5418],
[ 1.0027, 0.9398, 1.5202, -1.1660, -0.1230]]
和
[[ 0.5725, -1.7838, -0.7320, -1.4419, 1.5762],
[ 0.6407, 0.0527, 1.7005, 1.6350, -0.2610],
[ 1.3307, -0.3210, -1.7203, 0.9050, 0.2442],
[ 0.9418, -0.1511, 0.8248, -0.0786, -0.6153]]
還有
[[ 1.0182, 0.3190, -0.3408, -2.1801, -0.3931],
[ 1.2325, -0.3304, 1.0116, 0.0791, -1.1174],
[ 0.2331, -0.9062, 0.5680, 1.6061, -1.0933],
[ 0.6935, -0.5140, -0.5178, 1.2557, 0.2319]]
這三個子元素相應為位置比較大小,大的留下,生成新的張量,列如對于第一個位置,1.0182 比 0.5725 和 0.7106 大,所以它留下,它在元素在要是動手的維度里面的位置索引為2,其它同理
但是這個維度還之前還有一個維度,那么只要對所有的同維度的做相同操作就可以了,所以返回之如下
values=tensor([[[ 1.0182, 1.3332, -0.3408, -0.1609, 1.5762],[ 1.2325, 2.2507, 1.7005, 1.6350, 0.0066],[ 1.3307, 1.4097, 0.5680, 1.6061, 1.5418],[ 1.0027, 0.9398, 1.5202, 1.2557, 0.2319]],[[ 1.0916, 0.7171, -0.0681, 1.1741, 1.3904],[ 0.2073, 1.4668, 0.9670, -0.5629, 0.6097],[ 1.8680, 1.2638, 0.8455, -0.0353, 0.6855],[ 1.2070, 0.7683, 0.3014, -0.0724, 1.2062]]]),
indices=tensor([[[2, 0, 2, 0, 1],[2, 0, 1, 1, 0],[1, 0, 2, 2, 0],[0, 0, 0, 2, 2]],[[0, 0, 2, 0, 2],[1, 2, 2, 2, 1],[0, 1, 0, 1, 0],[0, 1, 2, 0, 0]]]))
總結
以上是生活随笔為你收集整理的Pytorch中tensor维度和torch.max()函数中dim参数的理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: keil5安装教程及下载
- 下一篇: 荣耀升级android版本最好用,到底好