怎样运营网站,阿里云自助建站模板,wordpress移动端页面模板下载地址,搭建网站需要程序带注释版本的实现被写到了这里#xff1a;https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/apex 由于有很多个人理解#xff0c;读者可配合当前文章谨慎理解。 0x0. 背景
我也是偶然在知乎的一个问题下看到这个问题#xff0c;大概就是说在使用apex的… 带注释版本的实现被写到了这里https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/apex 由于有很多个人理解读者可配合当前文章谨慎理解。 0x0. 背景
我也是偶然在知乎的一个问题下看到这个问题大概就是说在使用apex的LayerNorm/RMSNorm的时候可以打开这个api的memory_efficient开关这个开关可以在速度和精度无损的情况下节省网络训练的显存占用。感觉比较有趣我就研究了一下因此也就有了这篇文章。
我去实测了一下单机8卡A100训练LLama7B纯数据并行的情况下打开memory_efficient开关相比于不打开节省了大约2个G的显存如果模型继续scale up那么省掉的显存也会更多。因此本文就是对这个memory_efficient开关的背后实现做一个解读另外也会对apex里面LayerNorm/RMSNorm本身的cuda kernel实现做一个细节解读。
apex的LayerNorm/RMSNorm被实现成一个fuse kernel然后上层使用torch.autograd.Function来封装本文的讲解主要以LayerNorm为例子入口见https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm.py#L32-L51 。实际上RMSNorm和LayerNorm的实现是共享的只不过在kernel内部会区分一下缩放策略是2个参数LayerNorm的gamma和beta还是一个参数。
class FusedLayerNormAffineFunction(torch.autograd.Function):staticmethoddef forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficientFalse):global fused_layer_norm_cudaif fused_layer_norm_cuda is None:fused_layer_norm_cuda importlib.import_module(fused_layer_norm_cuda)ctx.normalized_shape normalized_shapectx.eps epsctx.memory_efficient memory_efficientinput_ input.contiguous()weight_ weight.contiguous()bias_ bias.contiguous()output, mean, invvar fused_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)if ctx.memory_efficient:ctx.save_for_backward(output, weight_, bias_, None, invvar)else:ctx.save_for_backward(input_, weight_, bias_, mean, invvar)return output可以看到在非memory_efficient模式下面ctx.save_for_backward(output, weight_, bias_, None, invvar)保存了用于backward的tensor包括输入权重偏置均值和方差的逆。但在memory_efficient模式下面ctx.save_for_backward(output, weight_, bias_, None, invvar)则是保存了输出权重偏置以及方差的逆。
这个地方看下你是否会掉入误区从表面上看这里也就只省掉了一个gamma因为输入和输出tensor的形状是一样的那么这样还有什么收益呢背景是在pre-ln的transformer架构里面LayerNorm/RMSNorm之后紧接着是一个线性投影无论是在注意力机制还是在多层感知机mlp中都是如此所以输出Tensor一定要被保存下来。而在post-ln架构中输出还会直接用于残差连接。然而在这两种情况下LayerNorm/RMSNorm的输入都不再被使用所以这里原本的输入保存变得相当多余因为我们可以保存无论如何都会被保存的输出张量。这样就可以达到节省显存的目的了。
接下来就详细解读下实现。
0x1. Apex的LayerNorm前向cuda实现
https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda.cpp 这个文件是基于实现的LayerNorm cuda kernel使用torch extension模块导出python接口。
同时这个文件还写了几个工具函数比如compute_n1_n2用来计算LayerNorm中非归一化和归一化部分的大小https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L7C31-L7C51 check_args函数对LayerNorm的参数进行检查https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L32C22-L143 。
此外这个cpp预定义了cuda_layer_norm的函数接口并且考虑了gamma/beta是否为空。
接下来就正式对LayerNorm的前向cuda实现进行解析。
0x1.1 工具函数
LayerNorm使用Welford算法统计均值方差在 https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu 写了一系列kernel实现中需要用到的工具函数这些函数是gpu上用到的。下面对其简单解析一下另外Welford算法可以看这篇博客的介绍用Welford算法实现LN的方差更新感叹一下zzk写这篇文章的时候还是萌新经过2年时间已经成长为国内顶级的工程师了开挂般学习能力 。工具函数包含cuWelfordOnlineSumcuChanOnlineSumcuRMSOnlineSumcuChanRMSOnlineSum这些我把自己的原始注释使用gpt4进行了润色这样会显得更加通俗一些。具体解释如下
// 这段代码是个CUDA函数名叫cuWelfordOnlineSum擅长用Welford算法来边收数据边算这些数据的平均值和变化范围就是均值和方差。
// 用Welford算法来算这个特别稳不会因为数据太多而出错而且每加一个数据就能更新一次均值和方差。
// const U curr: 这个是新来的数据点。
// U mu: 这个是我们到现在为止算出来的所有数据的平均值。
// U sigma2: 这个是我们到现在为止算出来的方差可以告诉你数据变化有多大。
// U count: 这个记录了我们到现在处理了多少数据点。templatetypename U __device__
void cuWelfordOnlineSum(const U curr,U mu,U sigma2,U count)
{count count U(1); // 每次调用这个函数就把处理的数据数量加一。U delta curr - mu; // 看看新数据和现有平均值差多少。U lmean mu delta / count; // 用这个差值和数据总量来算一个新的平均值。mu lmean; // 把这个新算的平均值记下来。U delta2 curr - lmean; // 现在再算一下新数据和新平均值的差。sigma2 sigma2 delta * delta2; // 利用这个新旧平均值的差来更新方差。
}// 这段代码是个CUDA函数名叫cuChanOnlineSum。它用于处理一种特殊的情况
// 当你有两堆数据想要快速算出它们合并后的平均值和方差时这个函数就派上用场了。
// const U muB, sigma2B, countB: 这三个是你新加入的那堆数据的平均值、方差和数据点数量。
// U mu, sigma2, count: 这三个是你之前已经有的数据的平均值、方差和数据点数量。
// 这个函数会更新这些值让它们反映出两堆数据合并后的情况。templatetypename U __device__
void cuChanOnlineSum(const U muB,const U sigma2B,const U countB,U mu,U sigma2,U count)
{U delta muB - mu; // 先算算新数据堆和老数据堆的平均值差了多少。U nA count; // 记下当前数据堆我们叫它A堆的大小。U nB countB; // 看看新来的那堆数据B堆有多少个点。count count countB; // 把两堆数据的数量加起来。U nX count; // 这就是合并后总数据量的大小。if (nX U(0)) {nA nA / nX; // 算一下A堆数据在总数据中占的比例。nB nB / nX; // 同理算一下B堆的比例。mu nA*mu nB*muB; // 利用这些比例和各自的平均值算出总的平均值。sigma2 sigma2 sigma2B delta * delta * nA * nB * nX; // 然后用一点复杂的公式把方差也算出来这个公式考虑了两堆数据的方差和它们平均值的差异。} else {// 如果合并后的总数是0那就说明两堆数据其实都是空的所以把平均值和方差都设为0。mu U(0);sigma2 U(0);}
}// 这里定义了一个名叫cuRMSOnlineSum的CUDA函数它的主要任务就是在线实时计算一串数据的平方和。
// 你可能会问为什么要算平方和呢这是因为我们可以用它来算出均方根RMS, Root Mean Square
// 均方根是一种描述数据波动大小的指标特别常用于信号处理领域。
templatetypename U __device__
void cuRMSOnlineSum(const U curr,U sigma2)
{sigma2 sigma2 curr * curr; // 每次函数被调用就把当前值的平方加到累计平方和中。
}// 又定义了一个名叫cuChanRMSOnlineSum的CUDA函数这个家伙的工作就是帮你算两组数据的平方和总和。
// 当你有两组数据想要快速合并它们的均方根RMS时这个函数就能派上用场。
// 它其实是均方根计算过程中的一个环节用于处理两个独立数据集的情况。
templatetypename U __device__
void cuChanRMSOnlineSum(const U sigma2B,U sigma2)
{sigma2 sigma2 sigma2B; // 这里就简单直接了把第二组数据的平方和加到当前的累计值上。
} 这里还有一个函数cuWelfordMuSigma2是用来计算张量某一维度上的均值mu和方差sigma2的它调用了上面的工具函数但是这个函数我们在kernel实现阶段解析因为它需要一些kernel启动的背景。
0x1.2 启动逻辑
先对kernel启动这部分的代码进行注释首先是共享内存的结构体定义。
// 这段代码定义了一个叫做SharedMemory的模板结构体专门用在CUDA设备函数里来访问所谓的“共享内存”。
// 在CUDA编程里共享内存是一种特别高效的内存类型非常适合用来在CUDA的一个块block内的不同线程间共享数据。
// 这里还包括了针对float和double类型数据的SharedMemory结构体的特化版本。namespace {
// 这是通用的SharedMemory结构体模板。注意我们通过在函数体内使用一个未定义的符号来阻止这个结构体被实例化
// 这样如果尝试用未特化的类型来编译这个结构体编译器就会报错。
// template typename T
// struct SharedMemory
// {
// // 确保我们不会编译任何未特化的类型
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };template typename T
struct SharedMemory;// 这是SharedMemory结构体针对float类型的特化版本。
template
struct SharedMemory float
{// 这个函数返回一个指向共享内存的float类型指针。__device__ float *getPointer(){ // 这里声明了一个名为s_float的外部共享内存数组用于存储float类型的数据。// extern和__shared__关键字表明这个数组是在共享内存中定义的。extern __shared__ float s_float[];return s_float;}
};// 下面是针对double类型的特化版本工作方式和float版本相似。
template
struct SharedMemory double
{__device__ double *getPointer(){extern __shared__ double s_double[];return s_double;}
};
}然后是Kernel启动的具体逻辑部分
// 这段代码里我们定义了一个CUDA设备函数叫做cuApplyLayerNorm_它的主要任务是执行LayerNorm层归一化。
// 层归一化是深度学习中的一个技巧用来让每一层的输出更加标准化有助于模型训练。
// 我们定义了三种模板参数T是输入数据类型U是中间计算比如均值和方差的类型V是输出数据类型。
// output_vals, mean, invvar, vals, gamma, beta 这些都是指向不同数据的指针。
// 在层归一化中我们通常把一个多维数据张量分为两部分一部分用来做标准化另一部分保持原样。
// 比如如果你有一个 [batch_size, channels, height, width] 形状的4D张量
// 而你只想对最后两个维度进行层归一化那么n1是batch_size * channelsn2是height * width。
templatetypename T, typename U, typename V __device__
void cuApplyLayerNorm_(V* __restrict__ output_vals,U* __restrict__ mean,U* __restrict__ invvar,const T* __restrict__ vals,const int n1,const int n2,const U epsilon,const V* __restrict__ gamma,const V* __restrict__ beta,bool rms_only)
{// 基本假设// 1) blockDim.x 是 warp 的大小这是一个CUDA的技术细节。// 2) 输入的张量数据在内存中是连续的。//// 这段代码遍历n1维度每次处理一个i1索引。// 假设每个CUDA线程块的x维度等于warp大小确保数据处理是高效的。// 这里一个线程可能要处理多行所以我们用gridDim.y来控制步长。因为gridDim.x1for (auto i1blockIdx.y; i1 n1; i1 gridDim.y) {SharedMemoryU shared;U* buf shared.getPointer(); // 创建一个 SharedMemory 实例用于处理类型 U 的数据。U mu,sigma2; // 这里mu和sigma2分别代表均值和方差我们接下来要计算它们。// 调用 cuWelfordMuSigma2 函数计算给定索引 i1 处的均值mu和方差sigma2。cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only);// 定位到当前 i1 索引处的输入和输出的起始位置。const T* lvals vals i1*n2;V* ovals output_vals i1*n2;// 计算逆方差 c_invvar这是层归一化中一个关键的步骤。U c_invvar rsqrt(sigma2 epsilon);// 计算每个 CUDA 块的线程总数 (numx) 和当前线程的一维索引 (thrx)。const int numx blockDim.x * blockDim.y;const int thrx threadIdx.x threadIdx.y * blockDim.x;// 如果提供了gamma和beta参数或者我们只是在做RMS计算我们会用一种特别的方式来计算输出值。if (gamma ! NULL (beta ! NULL || rms_only)) {for (int i thrx; i n2; inumx) {U curr static_castU(lvals[i]);if (!rms_only) {// 标准化当前值然后用gamma和beta进行调整。ovals[i] gamma[i] * static_castV(c_invvar * (curr - mu)) beta[i];} else {// // 如果是RMS模式我们稍微简化计算过程。ovals[i] gamma[i] * static_castV(c_invvar * curr);}}} // 否则如果没有提供gamma和beta我们就直接用计算出的均值和逆方差来进行标准化。else {for (int i thrx; i n2; inumx) {U curr static_castU(lvals[i]);if (!rms_only) {// 直接进行标准化计算。ovals[i] static_castV(c_invvar * (curr - mu));} else {// // RMS模式下的简化计算。ovals[i] static_castV(c_invvar * curr);}}}// 在每个 CUDA 块中仅由一个线程线程 (0,0)更新均值和逆方差。if (threadIdx.x 0 threadIdx.y 0) {if (!rms_only) {mean[i1] mu;}invvar[i1] c_invvar;}// 用于同步块内的所有线程。__syncthreads();}
}// 对上个函数的参数透传不过rms_only设为False
templatetypename T, typename U, typename VT __global__
void cuApplyLayerNorm(V* __restrict__ output_vals,U* __restrict__ mean,U* __restrict__ invvar,const T* __restrict__ vals,const int n1,const int n2,const U epsilon,const V* __restrict__ gamma,const V* __restrict__ beta)
{cuApplyLayerNorm_T, U, V(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false);
}// kernel启动代码设置线程块和线程数
templatetypename T, typename U, typename VT
void HostApplyLayerNorm(V* output,U* mean,U* invvar,const T* input,int n1,int n2,double epsilon,const V* gamma,const V* beta)
{// threads和blocks定义了CUDA内核的线程和块的维度。这里每个线程块有32×4的线程而块的数量由n1和GPU设备的最大网格大小限制决定。auto stream at::cuda::getCurrentCUDAStream().stream();const dim3 threads(32,4,1);const uint64_t maxGridY at::cuda::getCurrentDeviceProperties()-maxGridSize[1];const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);// 这段代码计算内核函数需要多少共享内存。如果threads.y大于1它会根据U类型的大小分配足够的内存。int nshared threads.y 1 ?threads.y*sizeof(U)(threads.y/2)*sizeof(U) :0;// 最后函数使用cuApplyLayerNorm kernel来执行实际的LayerNorm操作。// kernel函数的调用使用了之前计算的线程块和线程配置以及共享内存大小和CUDA流。cuApplyLayerNormblocks, threads, nshared, stream(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}这段代码包含了kernel的启动逻辑包括设置block的个数以及每个block中的线程排布方式然后在cuApplyLayerNorm_里面有一个跨线程网格的大循环作用在n1维度每个线程可能会处理多行数据。而在每一行数据的处理上调用了cuWelfordMuSigma2 函数计算给定索引 i1 处的均值mu和方差sigma2并随后在n2维度上来计算LayerNorm的输出同时会在每个Block的线程0, 0更新cuWelfordMuSigma2算出来的均值和方差这里的记录的实际上是方差的逆。
0x1.3 kernel实现
从上面的分析可知整个LayerNorm实现的核心就是cuWelfordMuSigma2函数下面对这个函数进行解析。
// cuWelfordMuSigma2 是一个CUDA设备函数旨在高效计算张量某一特定维度上的均值mu和方差sigma2。
// 它基于Welford算法实现以提高数值稳定性。此外该函数支持仅计算均方根RMS作为一种操作模式。
// 模板参数 typename T, typename U: 定义了处理张量值T和执行计算U时使用的数据类型。
// const T* __restrict__ vals: 指向张量数据的指针。
// const int n1, n2: 指定张量的维度其中n1是参与计算的维度的大小n2是被约减的维度的大小。
// const int i1: 当前正在处理的n1维度上的特定索引。
// U mu, sigma2: 用于存储计算得出的均值和方差。
// U* buf: 指向用于线程间通讯的共享内存缓冲区的指针。
// bool rms_only: 一个标志用于指示是否仅计算RMS为true时或同时计算均值和方差为false时。
templatetypename T, typename U __device__
void cuWelfordMuSigma2(const T* __restrict__ vals,const int n1,const int n2,const int i1,U mu,U sigma2,U* buf,bool rms_only)
{// 前提条件:// 1) blockDim.x 等于 warp 的大小。// 2) 输入的张量在内存中连续存储。// 3) 有足够的共享内存可用大小为 2*blockDim.y*sizeof(U) blockDim.y*sizeof(int)。//// 在 n2 维度上计算方差和均值。// 初始化 count, mu, 和 sigma2 为零。U count U(0);mu U(0);sigma2 U(0);// 确保处理的 i1 索引在张量的有效范围内。if (i1 n1) {// 一个warp处理一个n1索引同步是隐含的。// 用标准的Welford算法进行初始化。const int numx blockDim.x * blockDim.y; // 计算一个 CUDA 块中的线程总数。const int thrx threadIdx.x threadIdx.y * blockDim.x; // 计算当前线程在块内的唯一线性索引// 将 lvals 指针设置为指向当前处理的 i1 索引处张量的开始位置。// vals 是整个张量数据的起始指针i1*n2 计算出当前索引在张量中的线性位置。const T* lvals vals i1*n2;// 用局部变量l遍历张量的元素每个线程处理多个元素。int l 4*thrx;// 遍历张量的元素步长为4*numx每个线程处理四个元素。for (; l3 n2; l4*numx) {// 在每次外循环的迭代中处理四个连续的元素。for (int k 0; k 4; k) {// 将当前处理的元素值转换为计算使用的数据类型U。U curr static_castU(lvals[lk]);// 根据 rms_only 标志调用相应的函数来更新均值和方差或仅更新平方和用于计算 RMS。if (!rms_only) {cuWelfordOnlineSumU(curr,mu,sigma2,count);} else {cuRMSOnlineSumU(curr, sigma2);}}}// 这个循环处理了之前在步长为 4*numx 的循环中未处理的张量元素。每个线程独立处理它们剩余的部分。for (; l n2; l) {U curr static_castU(lvals[l]);if (!rms_only) {cuWelfordOnlineSumU(curr,mu,sigma2,count);} else {cuRMSOnlineSumU(curr, sigma2);}}// 在同一个warp内进行归约操作。for (int l 0; l 4; l) {// 是在 CUDA 设备上进行 warp 内部数据交换的关键部分。// 这行代码用于确定在一个 warp32个线程内每个线程应该从哪个“lane”即其他线程获取数据。// 1l这个操作在这里用于逐步增加要从中获取数据的线程的距离。例如当 l 为 0 时// 线程将从它的“邻居”线程即下一个线程获取数据当 l 为 1 时它将从两个位置之外的线程获取数据依此类推。// 这个表达式计算出当前线程应该从哪个线程获取数据。随着 l 的增加每个线程从越来越远的线程获取数据。// 31是因为在一个 warp 内线程索引是循环的。也就是说如果一个线程的索引计算结果是 32// 它实际上会从索引为 0 的线程获取数据索引为 33 的线程实际上是索引为 1 的线程依此类推。int srcLaneB (threadIdx.x(1l))31;// 是一种 warp 内部的快速数据交换操作用于从另一个线程srcLaneB获取数据。U sigma2B WARP_SHFL(sigma2, srcLaneB);// 如果不是只计算 RMS!rms_only则使用 cuChanOnlineSum 合并两个线程的 mu、sigma2 和 count。// 如果只计算 RMS则使用 cuChanRMSOnlineSum 合并 sigma2。if (!rms_only) {U muB WARP_SHFL(mu, srcLaneB);U countB WARP_SHFL(count, srcLaneB);cuChanOnlineSumU(muB,sigma2B,countB,mu,sigma2,count);} else {cuChanRMSOnlineSumU(sigma2B, sigma2);}}// threadIdx.x 0 has correct values for each warp// inter-warp reductions// 检查是否有多个 warp。如果 blockDim.y 大于 1则表示块中有多个 warp 需要进行reduce操作。if (blockDim.y 1) {// 为方差和均值的reduce操作分配共享内存。ubuf 用于存储方差和均值ibuf 用于存储计数。U* ubuf (U*)buf;U* ibuf (U*)(ubuf blockDim.y);// 这个循环是对 warp 间的reduce操作进行分层合并。for (int offset blockDim.y/2; offset 0; offset / 2) {// upper half of warps write to shared// 确保只有部分线程warp 的上半部分将其计算的结果写入共享内存。if (threadIdx.x 0 threadIdx.y offset threadIdx.y 2*offset) {const int wrt_y threadIdx.y - offset;if (!rms_only) {ubuf[2*wrt_y] mu;ibuf[wrt_y] count;}ubuf[2*wrt_y1] sigma2;}// 同步以等待共享内存存储完毕__syncthreads();// lower half merges// 此部分是对 warp 间数据的合并操作。// 确保只有部分线程warp 的下半部分从共享内存中读取数据并进行合并。if (threadIdx.x 0 threadIdx.y offset) {U sigma2B ubuf[2*threadIdx.y1];if (!rms_only) {U muB ubuf[2*threadIdx.y];U countB ibuf[threadIdx.y];cuChanOnlineSumU(muB,sigma2B,countB,mu,sigma2,count);} else {cuChanRMSOnlineSumU(sigma2B,sigma2);}}__syncthreads();}// threadIdx.x 0 threadIdx.y 0 only thread that has correct values// 最终的结果由块内的第一个线程threadIdx.x 0 threadIdx.y 0计算并写入共享内存。if (threadIdx.x 0 threadIdx.y 0) {if (!rms_only) {ubuf[0] mu;}ubuf[1] sigma2;}__syncthreads();// 如果不是只计算 RMS则还需要更新均值 mu。if (!rms_only) {mu ubuf[0];}// 计算最终的方差。sigma2 ubuf[1]/U(n2);// dont care about final value of count, we know count n2} // 如果块中只有一个 warpblockDim.y 1则通过 WARP_SHFL 直接在 warp 内进行数据交换和更新。else {if (!rms_only) {mu WARP_SHFL(mu, 0);}sigma2 WARP_SHFL(sigma2/U(n2), 0);}}cuWelfordMuSigma2函数就是在n2维度上使用工具函数章节的Weleford方法来完成均值和方差的计算然后这里还借助了共享内存来做warp内和warp间的reduce最终得到全局的均值和方差。
前向的kernel就分析到这里大家如果想对LayerNorm的优化做进一步的了解推荐看一下OneFlow的SoftMax和LayerNorm优化文章。CUDA优化之LayerNorm性能优化实践(https://zhuanlan.zhihu.com/p/443026261) 这篇文章也是讲解了LayerNorm的前向优化流程文章开头有一张性能的图
实际上在大模型时代我们的隐藏层维度已经越来越大了所以我们在实际训练的时候OneFlow版本的kernel相比于apex的layerNorm在13B之类的模型训练里就拿不到明显收益了。而在CV中由于做LayerNorm的维度可能相对小一些所以相比于apex的LayerNorm就可以取得明显加速。
0x2. Apex的LayerNorm反向cuda实现memory_efficient相关计算
在apex的LayerNorm反向实现时我们不仅要关注它的cuda kernel是怎么写的还要关注memory_efficient打开时是如何根据输出来计算梯度的。我们知道LayerNorm需要对输入gammabeta都计算梯度介于篇幅原因这里对实现得最复杂的gamma/beta的反向过程进行走读。
0x2.1 启动逻辑
这里从kernel的启动逻辑开始梳理
// 这是一个模板函数支持不同的数据类型T输入数据类型、
// U通常用于中间计算的数据类型默认为float、V输出数据类型默认与T相同。
// 参数包括输出梯度dout、均值mean、方差倒数invvar、输入或输出的PyTorch张量input_or_output、
// 两个维度参数n1、n2、gamma和beta参数、用于数值稳定的epsilon、输入梯度grad_input、
// gamma梯度grad_gamma和beta梯度grad_beta、以及一个指示是否优化内存使用的布尔值memory_efficient。
templatetypename T, typename Ufloat, typename VT
void HostLayerNormGradient(const V* dout,const U* mean,const U* invvar,at::Tensor* input_or_output,int n1,int n2,const V* gamma,const V* beta,double epsilon,T* grad_input,V* grad_gamma,V* grad_beta,bool memory_efficient)
{// 获取当前CUDA流以用于后续的CUDA内核调用。auto stream at::cuda::getCurrentCUDAStream().stream();// 如果gamma和beta不为NULL函数会计算它们的梯度。// 这涉及两个CUDA内核的调用cuComputePartGradGammaBeta和cuComputeGradGammaBeta。if (gamma ! NULL beta ! NULL) {// compute grad_gamma(j) and grad_beta(j)// part_size是分块计算梯度时的部分大小。const int part_size 16;// threads2定义了每个CUDA线程块中的线程数量32×4×1。const dim3 threads2(32,4,1);// blocks2定义了CUDA网格中的块数量其中n2维度被分成多个块以确保每个块可以处理n2中的一部分。const dim3 blocks2((n2threads2.x-1)/threads2.x,part_size,1);// 这部分代码计算用于CUDA内核的共享内存大小。nshared2_a和nshared2_b是基于线程和块维度的两种不同共享内存大小估算。const int nshared2_a 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x 1);const int nshared2_b threads2.x * threads2.y * sizeof(U);// 最终选择较大的一个估算值作为实际的共享内存大小nshared2。const int nshared2 nshared2_a nshared2_b ? nshared2_a : nshared2_b;// note (mkozuki): I can hard code part_grad_gammas dtype as float given that// the cuda_layer_norm_gradient doesnt support double.// 根据输入或输出张量的数据类型决定局部梯度张量part_grad_gamma和part_grad_beta的数据类型。// 如果输入或输出是半精度浮点数Half或BFloat16则使用单精度浮点数Float否则使用输入或输出的相同数据类型。const auto part_grad_dtype (input_or_output-scalar_type() at::ScalarType::Half || input_or_output-scalar_type() at::ScalarType::BFloat16) ?at::ScalarType::Float :input_or_output-scalar_type();// 创建两个新的PyTorch张量part_grad_gamma和part_grad_beta用于存储gamma和beta的局部梯度计算结果。at::Tensor part_grad_gamma at::empty({part_size,n2}, input_or_output-options().dtype(part_grad_dtype));at::Tensor part_grad_beta at::empty_like(part_grad_gamma);// 使用BOOL_SWITCH宏处理memory_efficient参数以决定是否使用内存高效版本的CUDA内核。// 调用cuComputePartGradGammaBeta内核计算gamma和beta的梯度。// 这个内核函数接收必要的输入参数并将梯度结果写入part_grad_gamma和part_grad_beta张量。BOOL_SWITCH(memory_efficient, MemoryEfficient, []{auto kernel cuComputePartGradGammaBetaT, U, V, MemoryEfficient;kernelblocks2, threads2, nshared2, stream(dout,input_or_output-DATA_PTRT(),n1,n2,mean,invvar,U(epsilon),gamma,beta,part_grad_gamma.DATA_PTRU(),part_grad_beta.DATA_PTRU(),epsilon,false);});// 定义了每个CUDA线程块中的线程数量32×8×1。const dim3 threads3(32,8,1);// 定义了CUDA网格中的块数量。在这里n2维度被分成多个块每个块的大小由threads2.x之前定义的线程数量确定。const dim3 blocks3((n2threads2.x-1)/threads2.x,1,1);// 这行代码计算了cuComputeGradGammaBeta内核所需的共享内存大小。它基于threads3线程块的维度和数据类型U的大小。const int nshared3 threads3.x * threads3.y * sizeof(U);// kernel 接收局部梯度张量part_grad_gamma和part_grad_beta、块大小part_size、// 维度参数n1和n2和指向梯度输出的指针grad_gamma和grad_beta。cuComputeGradGammaBetablocks3, threads3, nshared3, stream(part_grad_gamma.DATA_PTRU(),part_grad_beta.DATA_PTRU(),part_size,n1,n2,grad_gamma,grad_beta,false);}...
}这里省略了计算输入梯度的启动代码只看计算gamma和beta梯度的代码。可以发现这里对gamma和beta的梯度进行计算时使用了分块计算的方式首先会调用cuComputePartGradGammaBeta这个kernel计算出一个部分gamma和部分beta也就是part_grad_gamma和part_grad_beta需要注意这个kernel开启的线程块为const dim3 blocks2((n2threads2.x-1)/threads2.x,part_size,1)其中part_size16此外每个线程块中的线程排布为const dim3 threads2(32,4,1)即每个线程块有128个线程。我们可以简单算一下block2的大小threads2.x32那么blocks2(n2/32,16,1)也就是一共会有n2/2个线程块。
使用cuComputePartGradGammaBeta计算完局部gamma和beta的grad之后会调用cuComputeGradGammaBeta这个kernel来汇总全局的gamma和beta的梯度。这里开启的线程块为const dim3 blocks3((n2threads2.x-1)/threads2.x,1,1)而每个线程块里面有256个线程排布为const dim3 threads3(32,8,1)。
现在了解了线程块的组织方式就需要去kernel实现里面对应看一下具体是怎么计算的。
0x2.2 kernel计算逻辑
首先来看分段计算gamma和beta梯度的kernel实现注释如下
// part_size是分块计算梯度时的部分大小。
// const int part_size 16;
// threads2定义了每个CUDA线程块中的线程数量32×4×1。
// const dim3 threads2(32,4,1);
// blocks2定义了CUDA网格中的块数量其中n2维度被分成多个块以确保每个块可以处理n2中的一部分。
// const dim3 blocks2((n2threads2.x-1)/threads2.x,part_size,1);
// -
// blockDim.x 32, blockDim.y 4, gridDim.y 16
// 假设 n1 4, n2 256并且当前是第一个线程块
templatetypename T, typename U, typename V, bool MemoryEfficient __global__
void cuComputePartGradGammaBeta(const V* __restrict__ dout,const T* __restrict__ input_or_output,const int n1,const int n2,const U* __restrict__ mean,const U* __restrict__ invvar,U epsilon,const V* __restrict__ gamma,const V* __restrict__ beta,U* part_grad_gamma,U* part_grad_beta,const double eps,bool rms_only)
{// numsegs_n1计算n1维度4被分成多少段。使用blockDim.y*blockDim.y16作为分段大小。// 带入值numsegs_n1 (4 16 - 1) / 16 1。const int numsegs_n1 (n1blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);// segs_per_block计算每个线程块要处理的段数。// 带入值segs_per_block (1 16 - 1) / 16 1。const int segs_per_block (numsegs_n1 gridDim.y - 1) / gridDim.y;// 这些行计算当前线程块开始和结束处理n1维度的索引// i1_beg和i1_beg_plus_one相差segs_per_block * blockDim.y*blockDim.y1*4*416// 带入blockIdx.y 0i1_beg 0 * 1 * 4 * 4 0 i1_beg_plus_one 1 * 1 * 4 * 4 16i1_end min(16, 4) 4const int i1_beg blockIdx.y * segs_per_block * blockDim.y*blockDim.y;const int i1_beg_plus_one (blockIdx.y1) * segs_per_block * blockDim.y*blockDim.y;const int i1_end i1_beg_plus_one n1 ? i1_beg_plus_one : n1;// row_stride用于数组访问防止bank conflict。这里等于33const int row_stride blockDim.x1;// 计算每个线程在数据块中的偏移量。// 下限是0上限是31。const int thr_load_col_off (threadIdx.x*blockDim.y)(blockDim.x-1);// 这里的下限是0上限就是(31 * 4) / 32 3 * 415const int thr_load_row_off (threadIdx.x*blockDim.y)/blockDim.x threadIdx.y*blockDim.y;// i2_off是n2维度上的偏移量。const int i2_off blockIdx.x * blockDim.x thr_load_col_off;// 分配共享内存并设置两个缓冲区warp_buf1和warp_buf2SharedMemoryU shared;U* buf shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y (blockDim.y - 1)*(blockDim.x/blockDim.y) elementsU* warp_buf1 (U*)buf; // 大小是 31 * 4 * 4 496U* warp_buf2 warp_buf1 blockDim.y * blockDim.y * row_stride; // 大小是 3 * (32 / 4) 24// compute partial sums from strided inputs// do this to increase number of loads in flightcuLoadWriteStridedInputsT, U, V, MemoryEfficient(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only);// for循环处理每个数据块由i1_beg和i1_end确定。// 它在数据块之间以步幅blockDim.y*blockDim.y迭代允许不同的线程块处理不同的数据区域。for (int i1_block i1_begblockDim.y*blockDim.y; i1_block i1_end; i1_blockblockDim.y*blockDim.y) {cuLoadAddStridedInputsT, U, V, MemoryEfficient(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only);}// 确保在所有线程完成其加载和处理操作之前没有线程会继续执行后续的操作。__syncthreads();// inter-warp reductions// sum within each warp// 这部分代码执行内部归约计算每个warp内部的部分和。// acc1和acc2分别用于累积来自warp_buf1和warp_buf2的值。这些缓冲区包含之前步骤计算的中间结果。U acc1 U(0);U acc2 U(0);// 内部循环对于blockDim.y内的每一行进行累加if (!rms_only)条件检查是否需要执行特定的分支逻辑。// 需要特别注意这个累加实际上是在列方向上也就是n2维度在n2维度上一个线程负责计算blockDim.y列for (int k 0; k blockDim.y; k) {int row1 threadIdx.y k*blockDim.y;int idx1 row1*row_stride threadIdx.x;if (!rms_only) {acc1 warp_buf1[idx1];}acc2 warp_buf2[idx1];}// 累加的结果被写回到warp_buf1和warp_buf2中对应的位置。if (!rms_only) {warp_buf1[threadIdx.y*row_stridethreadIdx.x] acc1;}warp_buf2[threadIdx.y*row_stridethreadIdx.x] acc2;// 再次同步__syncthreads();// sum all warps// 这个循环是归约操作的一部分用于在warp之间求和。// offset初始化为blockDim.y/2每次迭代都减半这是一种常用的并行归约模式。// 这个也是在n2维度做for (int offset blockDim.y/2; offset 1; offset / 2) {// 在每次迭代中只有threadIdx.y小于当前offset的线程会参与计算这样可以避免重复的工作。if (threadIdx.y offset) {// idx1和idx2是基于row_stride计算的索引用于访问warp_buf1和warp_buf2。int row1 threadIdx.y;int row2 threadIdx.y offset;int idx1 row1*row_stride threadIdx.x;int idx2 row2*row_stride threadIdx.x;// 对于rms_only模式的分支处理只有在非rms_only模式下才更新warp_buf1。if (!rms_only) {warp_buf1[idx1] warp_buf1[idx2];}warp_buf2[idx1] warp_buf2[idx2];}// __syncthreads()在每次归约后同步所有线程确保所有线程都完成了它们的部分归约工作。__syncthreads();}// i2是一个计算出的索引表示当前线程在n2维度上的位置。int i2 blockIdx.x * blockDim.x threadIdx.x;// 如果当前线程的threadIdx.y为0且i2小于n2则线程参与最终梯度的计算。if (threadIdx.y 0 i2 n2) {int row1 threadIdx.y;int row2 threadIdx.y 1;int idx1 row1*row_stride threadIdx.x;int idx2 row2*row_stride threadIdx.x;if (!rms_only) {part_grad_beta[blockIdx.y*n2i2] warp_buf1[idx1] warp_buf1[idx2];}part_grad_gamma[blockIdx.y*n2i2] warp_buf2[idx1] warp_buf2[idx2];}
}在理解这段代码之前有一个大前提那就是这里的访问方式是n1是和blockDim.y绑定的而n2是和blockDim.x绑定的而且以二维矩阵的角度来看n1是在列方向而n2是在行的方向。然后const int row_stride blockDim.x1这一行是对共享内存进行padding避免Bank Conflict的而在计算时对共享内存的访问就是按照列来访问彻底避免bank conflict。
这也是为什么cuLoadWriteStridedInputs和cuLoadAddStridedInputs函数名中有一个Strided这也暗示了它们的访问模式是跨stride的。剩下的部分其实和前向就比较类似了做warp内和warp间的reduce。
另外一个值得注意的点是在cuLoadWriteStridedInputs和cuLoadAddStridedInputs计算时会根据memory_efficient开关选择不同的计算公式分别从输入和输出来计算出梯度达到kernel内部重计算的目的。
// 这段代码定义了一个名为 cuLoadWriteStridedInputs 的 CUDA 设备函数模板用于在计算LayerNorm的梯度时
// 从输入张量中加载数据并进行必要的计算将结果存储在 warp 缓冲区中。这个函数支持内存高效模式MemoryEfficient。
// 模板参数 T, U, V 代表不同的数据类型。
// bool MemoryEfficient 用于选择是否采用内存高效的方式处理数据。
// __device__ 表明这是一个 CUDA 设备函数。
// 函数参数包括各种用于LayerNorm梯度计算的数据
// 如输入/输出张量、梯度张量 dout、均值 mean、逆方差 invvar、缩放参数 gamma、偏移参数 beta 等。
templatetypename T, typename U, typename V, bool MemoryEfficient __device__
void cuLoadWriteStridedInputs(const int i1_block,const int thr_load_row_off,const int thr_load_col_off,const int i2_off,const int row_stride,U* warp_buf1,U* warp_buf2,const T* input_or_output,const V* dout,const int i1_end,const int n2,const U* __restrict__ mean,const U* __restrict__ invvar,const V* __restrict__ gamma,const V* __restrict__ beta,const double eps,bool rms_only)
{// 计算 i1表示当前处理的行索引。int i1 i1_blockthr_load_row_off;if (i1 i1_end) {for (int k 0; k blockDim.y; k) {// 计算列索引 i2 和用于加载和写入数据的索引。int i2 i2_off k;// load_idx 是从输入张量读取数据的索引write_idx 是在 warp 缓冲区写入数据的索引。int load_idx i1*n2i2;int write_idx thr_load_row_off*row_stridethr_load_col_offk;// 如果 i2 在有效范围内则从输入张量加载数据并进行必要的计算。if (i2n2) {U c_h static_castU(input_or_output[load_idx]);U curr_dout static_castU(dout[load_idx]);// 根据 rms_only 和 MemoryEfficient 的值使用不同的公式计算梯度并将结果存储在 warp 缓冲区中。if (!rms_only) {warp_buf1[write_idx] curr_dout;if (MemoryEfficient) {U curr_beta static_castU(beta[i2]);warp_buf2[write_idx] curr_dout * (c_h - curr_beta) / static_castU(clamp_by_magnitude(gamma[i2], eps));} else {warp_buf2[write_idx] curr_dout * (c_h - mean[i1]) * invvar[i1];}} else {if (MemoryEfficient) {warp_buf2[write_idx] curr_dout * (c_h) / static_castU(clamp_by_magnitude(gamma[i2], eps));} else {warp_buf2[write_idx] curr_dout * (c_h) * invvar[i1];}}} else {// 对于超出 n2 范围的索引将相应的 warp 缓冲区位置设置为 0。if (!rms_only) {warp_buf1[write_idx] U(0);}warp_buf2[write_idx] U(0);}}} else {// 对于超出 n1 范围的索引也将相应的 warp 缓冲区位置设置为 0。for (int k 0; k blockDim.y; k) {int write_idx thr_load_row_off*row_stridethr_load_col_offk;if (!rms_only) {warp_buf1[write_idx] U(0);}warp_buf2[write_idx] U(0);}}
}执行完cuComputePartGradGammaBeta这个kernel之后它的输出part_grad_gamma和part_grad_beta分别以行为n2列为n1的内存视角保存了LayerNorm的局部均值和方差的梯度。
接下来会使用cuComputeGradGammaBeta这个kernel来计算全局的均值和方差的梯度由于局部计算的时候分块大小是16而每个线程负责了4行的计算那么这里还需要累积16/44次以得到当前行的所有局部梯度的和。
// blockDim.x n2 / 32, blockDim.y 1
// threadDim.x 32, threadDim.y 8
templatetypename U, typename V __global__
void cuComputeGradGammaBeta(const U* part_grad_gamma,const U* part_grad_beta,const int part_size,const int n1,const int n2,V* grad_gamma,V* grad_beta,bool rms_only)
{// sum partial gradients for gamma and betaSharedMemoryU shared;U* buf shared.getPointer();// 计算每个线程的全局索引i2用于确定它在n2维度上的位置。int i2 blockIdx.x * blockDim.x threadIdx.x;// 如果线程索引i2小于n2的大小该线程会参与计算。if (i2 n2) {// each warp does sequential reductions until reduced part_size is num_warps// num_warp_reductions计算了每个warp需要进行的归约次数// 这取决于part_size每个块的大小和blockDim.y线程块的y维度大小。int num_warp_reductions part_size / blockDim.y;// sum_gamma和sum_beta分别用于累加gamma和beta的局部梯度。U sum_gamma U(0);U sum_beta U(0);// 这两行为每个线程设置了指向局部梯度数组的指针。这些指针根据线程的y索引和全局索引i2来确定。const U* part_grad_gamma_ptr part_grad_gamma threadIdx.y * num_warp_reductions * n2 i2;const U* part_grad_beta_ptr part_grad_beta threadIdx.y * num_warp_reductions * n2 i2;// 在这个循环中每个线程累加它负责的局部梯度部分。// warp_offset变量用于遍历分配给每个线程的梯度段。for (int warp_offset 0; warp_offset num_warp_reductions; warp_offset) {sum_gamma part_grad_gamma_ptr[warp_offset*n2];if (!rms_only) {sum_beta part_grad_beta_ptr[warp_offset*n2];}}// inter-warp reductions// nbsize3计算了用于存储归约中间结果的共享内存大小。// 在这个LayerNorm的启动参数下这里的blockDim.y是恒定为1的所以实际上这个reduce不会工作const int nbsize3 blockDim.x * blockDim.y / 2;// 外部循环减少归约参与的线程数每次迭代减半。for (int offset blockDim.y/2; offset 1; offset / 2) {// top half write to shared memory// 在这个归约阶段线程首先将其累加结果写入共享内存然后从共享内存读取并继续累加。if (threadIdx.y offset threadIdx.y 2*offset) {const int write_idx (threadIdx.y - offset) * blockDim.x threadIdx.x;buf[write_idx] sum_gamma;if (!rms_only) {buf[write_idxnbsize3] sum_beta;}}// __syncthreads()在每次迭代结束时同步所有线程确保共享内存的一致性。__syncthreads();// bottom half sumsif (threadIdx.y offset) {const int read_idx threadIdx.y * blockDim.x threadIdx.x;sum_gamma buf[read_idx];if (!rms_only) {sum_beta buf[read_idxnbsize3];}}__syncthreads();}// write out fully summed gradients// 如果线程是其warp中的第一个threadIdx.y 0它负责将完全累加的梯度写入全局内存。if (threadIdx.y 0) {grad_gamma[i2] sum_gamma;if (!rms_only) {grad_beta[i2] sum_beta;}}}
}注意for (int offset blockDim.y/2; offset 1; offset / 2) 这个循环包起来的代码在这里不会工作因为这个kernel的启动设置中 blockDim.y1。另外我们知道输入的数据已经是写到全局内存里面的了已经是同步之后的了然后每个线程累积4次这个过程也是从global memory里面先读再计算最后写回全局内存所以确实不需要再reduce了。
关于memory_efficient开关打开时的梯度计算公式按照 https://github.com/NVIDIA/apex/pull/1715 这个pr 来看应该就是把原始的输入用重计算的输入替换之后再代入到之前的梯度计算公式中算出来的。 https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L579 这里就对应了对gamma的梯度https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L582C5-L582C5 这里则对应了对beta的梯度。这里的 X ^ \hat{X} X^就等于 ( y − β ) / g a m m a (y-\beta)/gamma (y−β)/gamma公式和代码实现都能完整对应上。
0x3. 总结
这篇文章记录了笔者在研究大模型训练中偶然见到的一个Trick的代码解密过程希望对学习cuda的小伙伴有所帮助谢谢大家。