如何實現(xiàn)比?PyTorch?快?6?倍的?Permute/Transpose?算子?
無論是在統(tǒng)治NLP屆的Transformer,還是最近視覺領(lǐng)域的新秀Vision Transformer,我們都能在模型中看到Transpose/Permute算子的身影,特別是在多頭注意力機(jī)制(Multi-Head Attention)中,需要該算子來改變數(shù)據(jù)維度排布。
顯然,作為一個被高頻使用的算子,其CUDA實現(xiàn)會影響到實際網(wǎng)絡(luò)的訓(xùn)練速度。本文會介紹優(yōu)化Permute Kernel的技巧,并跟PyTorch的Permute,原生的Copy操作進(jìn)行實驗對比。1樸素的Permute實現(xiàn)
Permute算子的作用是變換張量數(shù)據(jù)維度的順序,舉個例子:
x = flow.randn(2, 3)
y = x.permute(1, 0)
y.shape
(3, 2)
其實現(xiàn)原理也可以很容易理解,即輸出Tensor的第i維,對應(yīng)輸入Tensor的dims[i]維,上述例子中 permute 實現(xiàn)對應(yīng)的偽代碼如下:
for row in x.shape[0]:
for col in x.shape[1]:
y[row][col] = x[col][row]
但是實際情況與上面的偽代碼有出入,張量的Shape是數(shù)學(xué)上的概念,在物理設(shè)備上并不真實存在。
張量的數(shù)據(jù)都是保存在一塊連續(xù)的內(nèi)存中,下圖分別從上層視角和底層視角描述了形狀為(2, 3)的張量的存儲方式:
Permute實現(xiàn)原理為:
-
通過當(dāng)前輸出的一維偏移量(offset)計算對應(yīng)的高維索引
-
然后根據(jù)參數(shù)dims重新排列輸出索引,進(jìn)而得到輸入索引。
-
將輸入索引轉(zhuǎn)換成輸入偏移量
-
最后進(jìn)行數(shù)據(jù)移動,整個過程的示意圖如下:
完成Permute后,輸出如下圖所示:
整個 permute 計算過程需要經(jīng)過多次一維偏移量offset和高維索引之間的轉(zhuǎn)換,為了避免一次次手工計算,提供了一個工具類NdIndexOffsetHelper來方便做上述轉(zhuǎn)換。
2NdIndexOffsetHelper
NdIndexOffsetHelper的主體方法如下:-
NdIndexToOffset方法把高維索引轉(zhuǎn)為一維偏移量
-
OffsetToNdIndex方法把一維偏移量轉(zhuǎn)為高維索引
有了這么一個工具類,那我們就可以很輕松的寫出一版Naive Permute Kernel了,核函數(shù)如下:template
__global__ void PermuteKernel(PermuteKernelParams params) {
using T = typename std::aligned_storage::type;
const T* src = reinterpret_cast(params.src);
T* dst = reinterpret_cast(params.dst);
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
for (size_t dim = 0; dim (2, 3, 0, 1)
x = flow.randn(3, 4, 5, 6)
y = x.permute(2, 3, 0, 1)
y.shape
(5, 6, 3, 4)
顯然這是一個四維的Permute情形,但這里第2,3維,第0,1維是一起Permute的,所以我們可以看成是一種二維的Permute情形:
# (0, 1, 2, 3) -> ((2, 3), (0, 1))
x = x.reshape(x.shape[0]*x.shape[1], x.shape[2]*x.shape[3])
y = x.permute(1, 0)
y = y.reshape(x.shape[2], x.shape[3], x.shape[0], x.shape[1])
合并維度后,在利用NdIndexOffsetHelper根據(jù)偏移量計算索引時,合并前需要計算成四維索引,而合并后我們只需計算成二維索引。相比合并前減少除法和乘法的次數(shù),進(jìn)而提升速度。
3. 使用更大的訪問粒度
細(xì)心的朋友們可能觀察到核函數(shù)中有一個模板參數(shù)size_t movement_size,它表示的是訪問元素的粒度。在Nvidia性能優(yōu)化博客increase Performance with Vectorized Memory Access中提到可以通過向量化內(nèi)存操作來提高CUDA Kernel性能,能夠減少指令數(shù),提高帶寬利用率。(鏈接:https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/)
我們設(shè)置訪問粒度的規(guī)則如下:
-
CUDA支持的訪問粒度為1B,2B,4B,8B,16B,粒度越大性能越好
-
最后一個維度是作為整體來移動的,即permutation[n-1]==x.dims[n-1],且大小是新訪問粒度的倍數(shù)
-
保證數(shù)據(jù)指針滿足新訪問粒度的對齊要求
針對規(guī)則2,對應(yīng)著以下Permute場景:(0, 1, 2, 3) -> (0, 2, 1, 3)其中最后一維并沒有變化,僅僅是第1,2維進(jìn)行交換,那么我們可以使用更大的訪問粒度來讀取數(shù)據(jù),再進(jìn)行Permute操作。代碼中通過GetMovementSize函數(shù)來確定訪問粒度的大小。
我們使用Nsight Compute對PyTorch的Permute和原生Copy操作對比測試運行時間和帶寬,測試結(jié)果如下:
其中測試環(huán)境為NVIDIA A100 40GB,場景為(0, 1, 2)->(1, 0, 2),橫坐標(biāo)表示數(shù)據(jù)形狀及數(shù)據(jù)類型。測試數(shù)據(jù)覆蓋了16MB到128MB不同大小的數(shù)據(jù),數(shù)據(jù)類型包含fp32和half兩種類型。
從上面兩張圖可以看到,在大部分情況下都可以逼近甚至略高于Copy操作的帶寬。與PyTorch對比,在操作耗時上最少快1.24倍,最快能達(dá)1.4倍。這里Permute的帶寬比原生Copy還高一點,是因為Copy Kernel里沒有做unroll指令間并行優(yōu)化,而Permute Kernel內(nèi)部做了相關(guān)優(yōu)化,這里僅做參考。使用上面的兩個優(yōu)化技巧,就能輕易做到比PyTorch的實現(xiàn)要快了。常規(guī)的Permute適用情況比較廣泛,也因此可能存在訪存不合并的情況。在一些特殊的場景下,我們可以通過合并訪存以提升帶寬利用率和速度,這就引出我們下個關(guān)于BatchTranspose優(yōu)化的話題。
4BatchTranspose優(yōu)化
BatchTranspose操作即矩陣轉(zhuǎn)置,僅交換矩陣最后的兩維,以下情況均符合BatchTranspose的定義,其中括號內(nèi)容表示維度的順序:
(0, 1) -> (1, 0)
(0, 1, 2) -> (0, 2, 1)
在樸素的Permute方案中,對于最后一維作為整體移動的情況下,已經(jīng)進(jìn)行充分的優(yōu)化。但實際場景中還存在矩陣轉(zhuǎn)置的情況,此時無法應(yīng)用第三條增大訪問粒度的優(yōu)化操作,并且不滿足訪存合并要求,導(dǎo)致性能不佳。以Pytorch為例,在數(shù)據(jù)大小為128MB情況下進(jìn)行BatchTranspose時,因為未合并的訪存導(dǎo)致實際讀取數(shù)據(jù)量遠(yuǎn)大于寫入數(shù)據(jù)量(7-8倍)。
在英偉達(dá)性能優(yōu)化博客An Efficient Matrix Transpose in CUDA C/C (https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/)中,其做法是設(shè)置一塊Shared Memory,然后將一行數(shù)據(jù)讀取到Shared Memory,再按列順序?qū)hared Memory中的元素寫回到Global Memory中。得益于Shared Memory訪問粒度小的特性(Global Memory是32B,Shared Memory是4B),進(jìn)而避免Global Memory的訪存不連續(xù)的問題。
Shared Memory相比Global Memory有15倍更高的帶寬,20-40倍更低的延遲,因此額外引入的讀寫開銷可以忽略不計。
此外我們給Shared Memory多padding了一個元素,進(jìn)而讓以列順序訪問的元素能夠均勻分布在32個bank上,避免bank conflict。對應(yīng)的示意圖如下(其中灰色部分代表Padding元素):
template
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType H, IndexType W,
IndexType num_tile_rows, IndexType num_tile_cols,
int32_t block_nums) {
using T = typename std::aligned_storage::type;
__shared__ T tile[tile_size][tile_size 1]; // To avoid bank conflict.
const T* src = reinterpret_cast(src_ptr);
T* dst = reinterpret_cast(dst_ptr);
IndexType batch_num_tile = num_tile_rows * num_tile_cols;
for (int i = blockIdx.x, step = gridDim.x; i