对 torch 中 dim 的总结和理解
生活随笔
收集整理的這篇文章主要介紹了
对 torch 中 dim 的总结和理解
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
pytorch 中,使用到 dim 參數的 api 都是跟集合有關的,比如 max(), min(), mean(), softmax() 等。當指定某個 dim 時,表示使用該維度的所有元素進行集合運算,一個 tensor 的 shape 為 (3, 4, 5),分別對應的 dim 如下所示
| 0 | 3 |
| 1 | 4 |
| 2 | 5 |
當使用 max(dim=1) 時,表示使用第二個維度中全部四個元素中的每個元素參與求最大值計算,計算后的 shape 變為 (3,5),因為只從 四個中求得最大的那個作為結果。如果 shape 的長度為 3,則 dim 的取值只能在區間 [-3, 2],否則將報錯。
Example
>>> a = torch.randn(3,4,5) # 求得第二個維度的最大值 >>> torch.max(a,1) torch.return_types.max( values=tensor([[0.7700, 0.1390, 0.6952, 1.9428, 0.8477],[1.0085, 0.7961, 0.9462, 2.1287, 0.9356],[1.1520, 2.1478, 0.8291, 1.0854, 0.7780]]), indices=tensor([[1, 1, 2, 2, 0],[1, 2, 2, 3, 0],[0, 1, 3, 3, 3]]))# 第二個維度縮減為只有一個元素,即 (3,1,5),api 將維度為 1 的去掉了 >>> torch.max(a,1).values.shape torch.Size([3, 5])# 第三個維度縮減為只有一個元素,即 (3,4,1),api 將維度為 1 的去掉了 >>> torch.max(a,2).values.shape torch.Size([3, 4])# 超出 dim 范圍,報錯 >>> torch.max(a,3).values.shape Traceback (most recent call last):File "<stdin>", line 1, in <module> IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)總結:
1、dim 是一種集合運算的參數,表示將某個維度的所有元素參與集合運算
2、dim 的取值和 shape 的長度密切相關,dim 的取值為 [-len(shape), len(shape)-1]
總結
以上是生活随笔為你收集整理的对 torch 中 dim 的总结和理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 揭秘淘宝搜索量快速暴增的秘密
- 下一篇: Coursera奖学金申请模板