从零开始写Qwen3(四)实现RMSNorm算子

张开发
2026/4/13 11:18:02 15 分钟阅读

分享文章

从零开始写Qwen3(四)实现RMSNorm算子
从零开始写Qwen3目录1. 概述已经搭建了基本模型可以推理并且应用了KVCache现在我们可以开始手写算子先从最简单的RMSNorm开始1.1 什么是 RMS Norm?每个特征向量除以方均根进行归一化再乘以一个 gamma 进行尺度缩放RMSNorm ( x ) i , j x i , j ∑ j x i , j 2 n ϵ ⋅ γ j \text{RMSNorm}(x)_{i,j} \frac{x_{i,j}}{\sqrt{\frac{\sum_j x_{i,j}^2}{n} \epsilon}}\cdot \gamma_jRMSNorm(x)i,j​n∑j​xi,j2​​ϵ​xi,j​​⋅γj​2. 环境搭建给torch写算子有两种方法直接在setup.py中写CUDAExtension/CPPExtension然后python安装的时候会调用ninja类似make的工具编译cpp/cu文件生成一个so文件自己写CMakeLists.txt自己配置pythonpytorch等依赖然后生成so文件通过cmake的软链接到对应目录下前者很方便不用自己找torch、python、cuda等依赖但缺点就是编译非常慢它和python代码绑在一起每次编译都需要刷新python的依赖这里选择CMake项目它编译快灵活度高整个项目配置一遍就不用管了2.2 项目结构qwen3_from_scratch/├── kernels/│ ├── rms_norm/│ │ ├── rms_norm.cpp # CPU 实现Python 入口 │ │ └── rms_norm.cu # CUDA 实现 │ └── kernels.h ├── pybind11.cpp # 模块注册 ├── CMakeLists.txt └── cmake/└── find_pytorch_vars.cmakepybind11.cpp负责注册模块和所有的函数一个目录一个算子.cpp负责cpu实现同时兼任算子入口.cu负责GPU实现2.3 CMake配置要点使用 Python 获取 Torch/Python/CUDA 的 cmake 路径设置 USE_CUDA 宏区分有/无 CUDA 环境链接 torch_python 库解决符号未定义问题编译后软链接到 Python 目录详细代码可以参考代码仓中的CMakeLists.txt3. 模块注册在py11bind.cpp中写#includekernel.hPYBIND11_MODULE(TORCH_EXTENSION_NAME,m){m.doc()qwen3_from_scratch kernels;m.def(rms_norm_forward,rms_norm_forward,RMSNorm forward computation (CPU/CUDA),py::arg(x),py::arg(gamma),py::arg(eps)1e-6f);}kernels.h中包含了torch/extensions.h这里注册了一个叫做ops的模块导出了一个叫做rms_norm_forward的函数包含三个参数4. CPU实现cpu实现都比较简单不做优化说明原理验证准确性就行for(inti0;iseqLen;i){constT*currentXxi*hiddenDimStride;T*currentOutputoutputi*hiddenDim;// 1. 计算当前序列位置的平方和floatsumSq0.0f;for(intk0;khiddenDim;k){constfloatvalstatic_castfloat(currentX[k]);sumSqval*val;}// 2. 计算均方根RMSconstfloatrmssqrtf(sumSq/static_castfloat(hiddenDim)eps);// 3. 归一化 缩放for(intj0;jhiddenDim;j){constfloatvalstatic_castfloat(currentX[j]);currentOutput[j]static_castT(val/rms*static_castfloat(gamma[j]));}}简单循环就行外部把B、H、S合并成一个也就是 reshape(-1, D)到这里就成为二维数据5. CUDA实现5.1 CUDA 并行层级回顾线程束(warp)32个线程同时执行相同指令但每个线程有自己的寄存器可以读取不同指令也就是SIMD单指令多数据块(block)块包含多个线程束最多1024个线程块是调度基本单位同时也是共享内存和线程同步的基本单位最好不要跨块通信网格(grid)块就是网格上的一个点一次函数执行会把网格上所有点都执行完5.2 计算步骤求方均根需要归约每个元素除以方均根完全并行乘以 gamma完全并行关键就在于步骤1的规约操作5.3 规约优化策略5.3.1 分层规约首先根据线程块长度将向量分为多个部分然后求和每个线程一个于是长度为N的向量变成了长度为整块线程数量的结果for(uint32_titid;ihiddenDim;iblockSize){sumSqtempX[tempXj]*tempX[tempXj];}这样块中每个线程拥有一个变量sumSq总共blockSize这么多个线程然后下一步就要跨线程汇总需要用到共享内存把每个线程的值存进去开始下一步汇总跨线程汇总最典型的模式就是这样for(intwwidth/2;w0;w/2){if(tidw){smem[tid]smem[tidw];}__syncthreads();}这样每次把一半的元素累加最后得到只剩一个数。这种方法可以让参与计算的线程尽可能都在前几个warp中而不是分散在各个warp中可以降低线程分化5.3.2 线程束打散规约由于可能不止涉及一个线程束必须等待所有线程束计算完毕会产生大量等待消耗cuda针对这种操作专门提供了线程束内的同步函数使用这种函数就不需要线程束之间的同步而且巧妙的是一个线程块至多1024个线程正好是32*32所以进行一次线程束汇总后只剩下至多32个数据只用放到一个线程束中执行一次线程束汇总就行只需要一次同步线程束内同步函数通常叫做shfl_xxx_sync大致使用就是shlf_xxx_sync(val,mask,offset)val是单个变量这就相当于是把整个线程束的val当成一个数组第i个线程(i叫做 laneId )对应val[id]if(((1i)mask)(xxx(i,offset)0xxx(i,offset)32){val[xxx(i,offset)]}常见的有shfl_down_sync就是xxx(i, offset)ioffsetshlf_up_sync类似shlf_xor_sync就是xxx(i,offset)i^offset可以使用这个函数对32个线程进行无同步汇总templateintwidthWARP_SIZE,typenameT__device__ __forceinline__ Twarp_reduce_sum(T x){#pragmaunrollfor(intoffsetwidth/2;offset0;offset1){x__shfl_xor_sync(0xffffffff,x,offset,width);}returnx;}由于width已知所以可以把循环展开成多条5条指令把分支判断和跳转都去掉为了方便展示假设线程束只有8个线程执行结果如下初始[a0,a1,a2,a3,a4,a5,a6,a7]step1[(a0a4),(a1a5),(a2a6),(a3a7),(a4a0),(a5a1),(a6a2),(a7a3)]step2[(a0a4a2a6),(a1a5a3a7),(a2a6a0a4),(a3a7a1a5),(a4a0a6a2),(a5a1a7a3),(a6a2a4a0),(a7a3a5a1)]step3[(a0a4a2a6a1a5a3a7),(a1a5a3a7a0a4a2a6),(a2a6a0a4a3a7a1a5),(a3a7a1a5a2a6a0a4),(a4a0a6a2a5a1a7a3),(a5a1a7a3a4a0a6a2),(a6a2a4a0a7a3a5a1),(a7a3a5a1a6a2a4a0)]可以看到所有线程最后结果都变成一样而down和up不是。一般这种操作只会输出线程0的结果这样计算出来剩下不到32个有效值用共享内存同步一次再执行一次就行__shared__ T s_sum[32];constuint32_twarpIdtid/WARP_SIZE;constuint32_tlaneIdtid%WARP_SIZE;if(laneId0){s_sum[warpId]sumSq;}__syncthreads();sumSq0.0f;if(laneId(blockSize/WARP_SIZE)){sumSqs_sum[laneId];}sumSqwarp_reduce_sum(sumSq);这里使用xor的好处就来了xor版的warp_reduce_sum中所有线程拿到相同的值就省去再获取一次值5.4 完整流程templatesize_t blockSize,size_t hiddenDim__global__voidrms_norm_kernel_arr(constfloat*__restrict__ x,float*__restrict__ output,constfloat*__restrict__ gamma,constintseqLen,constinthiddenDimStride,constfloateps){uint32_ttidthreadIdx.x;uint32_tblockIdblockIdx.x;constT*x_ptrxblockId*hiddenDimStride;// 如果是连续hiddenDimStride就是 hiddenDimoutputblockId*hiddenDim/* *1 output是刚申请的stride肯定是1*/;floatsumSq0.0f;#pragmaunrollfor(uint32_titid;ihiddenDim;iblockSize){floatxx_ptr[i];sumSqx*x;}sumSqwarp_reduce_sum(sumSq);ifconstexpr(blockSizeWARP_SIZE){static_assert((blockSize1024)(blockSize%WARP_SIZE0),blockSize must be a multiple of warpSize);__shared__floats_sum[32];constuint32_twarpIdtid/WARP_SIZE;constuint32_tlaneIdtid%WARP_SIZE;if(laneId0){s_sum[warpId]sumSq;}__syncthreads();sumSq0.0f;if(laneId(blockSize/WARP_SIZE)){sumSqs_sum[laneId];}sumSqwarp_reduce_sum(sumSq);}constfloatmeansumSq/hiddenDim;constfloatscalersqrtf(meaneps);#pragmaunrollfor(uint32_titid;ihiddenDim;iblockSize){floatxx_ptr[i];output[i]static_castT(x*scale*gamma[i]);}}全程只需要一个核函数就可以解决6. 性能测试与对比6.1 测试环境GPU: RTX 3060对比基准: PyTorch 2.10 nn.functional.rms_norm测试数据: B×128×1024多种数据类型6.2 性能结果数据类型平均加速比最佳加速比bfloat161.29x1.66x (Dim128)float161.20x1.50x (Dim16384)float321.29x1.67x (Dim64)详情见 rms_norm_benchmark_report.md6.3 torch融合算子对比如果在torch2.8及以前的版本测试这个例子会发现提升更加明显甚至可以到几倍因为torch2.8之前nn.functional.rms_norm算子没有融合是多个算子接连计算性能大打折扣可以看2.8版本而2.10版本的2.8版本是先调用pow(2)然后求均值然后相加再做除法而2.10只有一个函数6.4 CUDA加速总结融合操作减少内存访问次数Warp shuffle 比共享内存归约更高效减少 Python 层调用开销

更多文章