文件乱码了怎么恢复,百度网站怎么优化排名,wordpress网站首页链接乱码,wordpress登录ssoTransformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述
模型训练框架#xff1a;例如pytorch框架的cuda context… Transformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述
模型训练框架例如pytorch框架的cuda context会占用大约几百MB显存与版本有关模型参数大小比如7B的模型以FP16格式要占用14GB显存前向计算过程中产生的临时Tensor这部分Tensor需要被临时保存以便在反向传播计算梯度时使用反向传播计算得到的梯度优化器状态全量微调的情况下梯度与参数一样大普通SGD没有动量一阶动量优化器的自身参数大小与模型大小一样比如momentum-SGD二阶动量优化器一般为模型大小的两倍比如Adam transformer系列的大模型最常用的是Adam优化器
2 前向计算临时Tensor显存占用
2.1 self-attention显存占用
这部分Tensor的大小和模型的每一层结构形状有关必须根据具体模型的每层形状来计算也和具体的batch_size大小以及输入数据input_data的大小有关。
输入矩阵I:首先计算 Q I ∗ W q Q I * W^{q} QI∗Wq K I ∗ W k K I * W^{k} KI∗Wk V I ∗ W v V I * W^{v} VI∗Wv输入I是临时Tensor假设输入I的形状为 [b, s, d]元素个数为 bsd占用显存大小为2bytes*bsd2bsd bytes. Q K T QK^{T} QKTQ和K是临时Tensor假设形状为 [b, s, d]元素个数为 bsd占用显存大小为22bytesbsd4bsd bytes。softmax A Q K T AQK^{T} AQKT输入形状[b, h, s, d] × [b, h, s, d]A矩阵输出形状为 [b, h, s, s]h是头个数。保存A矩阵占用的显存大小为2bytes* b h s 2 bhs^{2} bhs2 2 b h s 2 2bhs^{2} 2bhs2 bytes。dropout:需要保存一个mask矩阵mask矩阵的形状与A相同mask矩阵的元素为0或1用1个byte表示占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。score* V加权score矩阵的形状与A相同占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d]占用显存大小为2bytes*bsd2bsd bytes。该步骤占用显存大小为 2 b h s 2 2 b s d 2bhs^{2}2bsd 2bhs22bsd bytes。 W O W^{O} WO输出映射需要临时保存输入矩阵形状[b, s, d]占用显存大小为2bytes*bsd2bsd bytes。dropout需要保存一个mask矩阵mask矩阵的形状为上一步输出形状[b, s, d]mask矩阵的元素为0或1用1个byte表示占用显存大小为1bytes*bsdbsd bytes。 综上步骤self-attention块的占用显存大小为2bsd4bsd 2 b h s 2 2bhs^{2} 2bhs2 2 b h s 2 2bhs^{2} 2bhs2 2 b h s 2 2 b s d 2bhs^{2}2bsd 2bhs22bsd2bsd2bsd11bsd 5 b h s 2 5bhs^{2} 5bhs2
2.2 MLP显存占用
第一个线性层需要保存其输入输入形状为[b, s, d]占用显存大小为 2bytes*bsd2bsd bytes。激活函数需要保存其输入为第一步的输出形状为[b, s, 4d]占用显存大小为2bytes*4bsd8bsd bytes。第二个线性层需要保存其输入输入形状为[b, s, 4d]占用显存大小为2bytes*4bsd8bsd bytes。最后有一个dropout操作需要保存mask矩阵形状是上一步的输出形状[b, s, d]mask矩阵的元素为0或1用1个byte表示占用显存大小为1bytes*bsdbsd bytes。
综上步骤MLP的占用显存大小为2bsd8bsd8bsdbsd19bsd.
3 梯度和优化器显存占用
3.1 模型训练过程两者显存占用
参数占用显存 参数数目 × n n 2 : float16 n 4 : float32 n 8 : double64 其中float32是最常用的类型n是数据类型占用的bytes。 训练过程通常为模型参数前向传播反向传播计算梯度优化器更新以Adam优化器为例分析假如模型参数量为P
混合精度训练 1使用float16的模型参数进行前向传递和反向传播计算得到float16的梯度 2在优化器更新模型参数时使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。 3对于每个可训练模型参数模型参数在步骤1和步骤2分别是2bytes4bytes梯度在步骤1和步骤2分别是分别是2bytes4bytes优化器状态是2* 模型大小2*4bytes8bytes。
每个参数占用24248 20bytes。模型参数量M时总计20P bytes。
普通训练 上述步骤12均使用float32类型。对于每个可训练模型参数模型参数在步骤1和步骤2分别是4bytes4bytes梯度在步骤1和步骤2分别是分别是4bytes4bytes优化器状态是2* 模型大小2*4bytes8bytes。
每个参数占用44448 24bytes模型参数量M时总计24P bytes。
3.2 模型推理过程两者显存占用
推理占用显存主要是模型参数假如模型参数量为P使用float16来进行推理推理阶段模型参数占用的显存约2P bytes使用float32来进行推理推理阶段模型参数占用的显存约 4P bytes。
参考文章https://zhuanlan.zhihu.com/p/624740065?utm_id0