网站免费诊断,产品展示的手机网站,wordpress模板不一样,北京平面设计公司目录 1 引言2 操作步骤和公式说明2.1 准备教师模型#xff08;Teacher Model#xff09;和学生模型#xff08;Student Model#xff09;2.2 生成软标签#xff08;Soft Labels#xff09;2.3 定义蒸馏损失函数2.4 训练学生模型2.5 调整超参数2.6 评估与部署 3 其他知识蒸… 目录 1 引言2 操作步骤和公式说明2.1 准备教师模型Teacher Model和学生模型Student Model2.2 生成软标签Soft Labels2.3 定义蒸馏损失函数2.4 训练学生模型2.5 调整超参数2.6 评估与部署 3 其他知识蒸馏技术4 实践参考文献 1 引言
近年来随着Transformer、MOE架构的提出使得深度学习模型轻松突破上万亿规模参数从而导致模型变得越来越大因此我们需要一些大模型压缩技术来降低模型部署的成本并提升模型的推理性能。而大模型压缩主要分为如下几类剪枝Pruning、知识蒸馏Knowledge Distillation、量化Quantization、低秩分解Low-Rank Factorization。
模型压缩方法分类
技术概述图
大规模语言模型LLM近年来在自然语言处理领域取得了巨大进步使得人类对话和文本生成成为可能。然而开源LLM模型由于参数规模较小性能难以达到商业LLM的水平。知识蒸馏技术可以解决这一问题它通过利用商业LLM的高性能将其知识“蒸馏”Knowledge Distillation知识蒸馏简称KD到更小的开源模型中从而实现高性能和低成本。 模型蒸馏Model Distillation最初由Hinton等人于2015年在论文《Distilling the Knowledge in a Neural Network》提出其核心思想是通过知识迁移的方式将一个复杂的大模型教师模型的知识传授给一个相对简单的小模型学生模型简单概括就是利用教师模型的预测概率分布作为软标签对学生模型进行训练从而在保持较高预测性能的同时极大地降低了模型的复杂性和计算资源需求实现模型的轻量化和高效化。
下面是模型蒸馏的要点
首先需要训练一个大的模型这个大模型也称为 teacher 模型。利用 teacher 模型输出的概率分布训练小模型小模型称为 student 模型。训练 student 模型时包含两种 labelsoft label 对应了 teacher 模型输出的概率分布而 hard label 是原来的 one-hot label。模型蒸馏训练的小模型会学习到大模型的表现以及泛化能力。
2 操作步骤和公式说明
2.1 准备教师模型Teacher Model和学生模型Student Model
教师模型已经训练好的高性能大型模型如BERT、GPT等。学生模型结构更简单的小型模型如TinyBERT、DistilBERT等参数量远小于教师模型。
公式说明 假设教师模型的输出概率分布为 pt学生模型的输出概率分布为 ps。
2.2 生成软标签Soft Labels
操作 对训练数据中的每个样本 x用教师模型计算其输出概率分布软标签。 通过引入温度参数T 平滑概率分布使类别间的关系更明显。
公式说明 教师模型的软标签计算 其中
zt 是教师模型的原始输出logitsT是温度参数通常 T 1如 T3。
高温T使概率分布更平滑学生模型能学习到类别间的隐含关系例如“猫”和“狗”的相似性硬标签One-hot编码仅包含0/1信息而软标签包含更多知识。
2.3 定义蒸馏损失函数
操作 设计总损失函数结合蒸馏损失模仿教师模型和任务损失拟合真实标签。
公式说明
蒸馏损失KL散度 其中
zs是学生模型的logits乘以 T2 是为了平衡温度缩放对梯度的影响。
任务损失交叉熵 其中
是学生模型的原始概率分布。
总损失
其中 α 是蒸馏损失的权重取值范围通常是 [0.5, 0.9]。
蒸馏损失强制学生模仿教师的概率分布任务损失确保学生模型不偏离真实标签温度T和权重 α需调参以平衡两者。
2.4 训练学生模型
操作 使用教师生成的软标签和真实标签联合训练学生模型优化总损失 Ltotal。 训练时需注意 温度参数T训练阶段使用高温如T3推理阶段恢复T1。梯度更新同时优化学生对教师分布和真实标签的拟合。
公式说明 反向传播时总损失的梯度计算为
其中 θs是学生模型的参数。
注意训练时高温 ( T ) 增强知识迁移推理时恢复标准概率分布。
2.5 调整超参数
温度T 较高的T如3~10增强软标签的平滑性适合复杂任务。较低的T如1~2贴近原始分布适合简单任务。 损失权重 α 若教师模型质量高可增大 α如0.7~0.9。若真实标签噪声小可增大任务损失权重。
示例调参策略
两阶段训练 第一阶段高T和大α专注于学习教师知识。第二阶段逐渐降低T和α贴近真实任务。
2.6 评估与部署
评估指标 学生模型在测试集上的准确率、F1值等任务指标。计算学生模型与教师模型的输出相似性如KL散度。 部署 学生模型以T1运行直接输出原始概率分布 psraw。
核心思想 通过教师模型的软标签富含类别间关系和学生模型的任务损失保留真实标签信息蒸馏实现了知识的迁移。
温度T 是核心超参数控制知识迁移的“清晰度”。两阶段训练先学教师再微调是常见优化策略。
3 其他知识蒸馏技术 4 实践
以下是一个简单的模型蒸馏代码示例使用一个预训练的ResNet-18模型作为教师模型并使用一个简单的CNN模型作为学生模型。同时将使用交叉熵损失函数和L2正则化项来优化学生模型的性能表现。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms# 定义教师模型和学生模型
teacher_model models.resnet18(pretrainedTrue)
student_model nn.Sequential(nn.Conv2d(3, 64, kernel_size3, stride1, padding1),nn.ReLU(),nn.MaxPool2d(kernel_size2, stride2),nn.Conv2d(64, 128, kernel_size3, stride1, padding1),nn.ReLU(),nn.MaxPool2d(kernel_size2, stride2),nn.Flatten(),nn.Linear(128 * 7 * 7, 10)
)# 定义损失函数和优化器
criterion nn.CrossEntropyLoss()
optimizer_teacher optim.SGD(teacher_model.parameters(), lr0.01, momentum0.9)
optimizer_student optim.Adam(student_model.parameters(), lr0.001)# 训练数据集
transform transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
trainset datasets.MNIST(../data, trainTrue, downloadTrue, transformtransform)
trainloader torch.utils.data.DataLoader(trainset, batch_size64, shuffleTrue)# 蒸馏过程
for epoch in range(10):running_loss_teacher 0.0running_loss_student 0.0for inputs, labels in trainloader:# 教师模型的前向传播outputs_teacher teacher_model(inputs)loss_teacher criterion(outputs_teacher, labels)running_loss_teacher loss_teacher.item()# 学生模型的前向传播outputs_student student_model(inputs)loss_student criterion(outputs_student, labels) 0.1 * torch.sum((outputs_teacher - outputs_student) ** 2)running_loss_student loss_student.item()# 反向传播和参数更新optimizer_teacher.zero_grad()optimizer_student.zero_grad()loss_teacher.backward()optimizer_teacher.step()loss_student.backward()optimizer_student.step()print(fEpoch {epoch1}/10 \t Loss Teacher: {running_loss_teacher / len(trainloader)} \t Loss Student: {running_loss_student / len(trainloader)})
在这个示例中 1首先定义了教师模型和学生模型并初始化了相应的损失函数和优化器 2然后加载了MNIST手写数字数据集并对其进行了预处理 3接下来进入蒸馏过程对于每个批次的数据首先使用教师模型进行前向传播并计算损失函数值然后使用学生模型进行前向传播并计算损失函数值同时加入了L2正则化项以鼓励学生模型学习教师模型的输出 4最后对损失函数值进行反向传播和参数更新打印了每个批次的损失函数值以及每个epoch的平均损失函数值。 通过多次迭代训练后我们可以得到一个性能较好且轻量化的学生模型。
参考文献
A Survey on Knowledge Distillation of Large Language ModelsA Survey on Model Compression for Large Language ModelsBERT模型蒸馏指南知乎