简单网站建设优化,公司备案证查询网站查询,开发网站的语言,辽宁省建设厅安全员考试官方网站早停法#xff08;Early Stopping#xff09;是一种用于防止模型过拟合的技术#xff0c;在训练过程中监视验证集#xff08;或者测试集#xff09;上的损失值。具体设立早停的限制包括两个主要参数#xff1a; Patience#xff08;耐心#xff09;#xff1a;这是指验…早停法Early Stopping是一种用于防止模型过拟合的技术在训练过程中监视验证集或者测试集上的损失值。具体设立早停的限制包括两个主要参数 Patience耐心这是指验证集损失在连续多少个epoch没有显著改善时才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时表示模型可能已经过拟合或者陷入局部最优点这时候早停就会被触发。 Best Loss最佳损失这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时更新最佳损失并重置耐心计数器。如果验证集损失连续不降耐心计数器超过设定的耐心值时早停就会被触发训练过程停止。 早停的具体设立是基于验证集上的损失值 val_loss。每次验证后如果当前的 val_loss 比 best_loss 还要低就更新 best_loss 并重置 patience_counter否则增加 patience_counter。当 patience_counter 达到设定的 patience 值时早停被触发即停止训练过程以防止模型过拟合。 总结来说早停的设立限制是基于耐心参数和最佳损失值用来判断模型是否应该停止训练以避免过拟合。
# 训练模型
num_epochs 200 # 总的训练轮数
best_loss float(inf) # 初始化最佳验证损失为正无穷大
patience 10 # 早停的耐心值
patience_counter 0 # 耐心计数器for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad() # 梯度清零outputs model(geno) # 前向传播loss criterion(outputs.squeeze(), pheno) # 计算损失loss.backward() # 反向传播optimizer.step() # 优化模型参数model.eval()val_loss 0with torch.no_grad(): # 不计算梯度for geno, pheno in test_loader:outputs model(geno) # 前向传播val_loss criterion(outputs.squeeze(), pheno).item() # 计算验证损失val_loss / len(test_loader) # 计算平均验证损失print(fEpoch [{epoch 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f})scheduler.step(val_loss) # 更新学习率# 早停法if val_loss best_loss:best_loss val_loss # 更新最佳验证损失patience_counter 0 # 重置耐心计数器else:patience_counter 1 # 增加耐心计数器if patience_counter patience: # 如果耐心计数器达到设定的耐心值print(Early stopping triggered) # 触发早停breakEarlyStopping 类 __init__ 方法初始化早停的参数如 patience耐心值、verbose是否打印消息和 delta损失改进的最小变化。__call__ 方法根据验证损失来决定是否更新 best_loss以及是否增加计数器或者触发早停。训练循环 训练和验证过程与之前相同。每个epoch结束时调用 early_stopping 对象传入当前的验证损失。检查 early_stopping.early_stop 标志如果为 True则打印消息并停止训练。
通过使用 EarlyStopping 类你可以更简洁和模块化地实现早停功能使代码更易于维护和扩展。
import torch
import numpy as npclass EarlyStopping:def __init__(self, patience10, verboseFalse, delta0):EarlyStopping 初始化.Args:patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.verbose (bool): 如果为True则每次验证集损失改进时会打印一条消息.delta (float): 验证集损失改进的最小变化.self.patience patienceself.verbose verboseself.delta deltaself.best_loss Noneself.counter 0self.early_stop Falsedef __call__(self, val_loss):if self.best_loss is None:self.best_loss val_losselif val_loss self.best_loss - self.delta:self.counter 1if self.verbose:print(fEarlyStopping counter: {self.counter} out of {self.patience})if self.counter self.patience:self.early_stop Trueelse:self.best_loss val_lossself.counter 0if self.verbose:print(fValidation loss decreased to {self.best_loss:.6f}. Resetting counter.)# 初始化EarlyStopping对象
early_stopping EarlyStopping(patience10, verboseTrue)# 训练模型
num_epochs 200
for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad()outputs model(geno)loss criterion(outputs.squeeze(), pheno)loss.backward()optimizer.step()model.eval()val_loss 0with torch.no_grad():for geno, pheno in test_loader:outputs model(geno)val_loss criterion(outputs.squeeze(), pheno).item()val_loss / len(test_loader)print(fEpoch [{epoch 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f})scheduler.step(val_loss)# 检查是否触发早停early_stopping(val_loss)if early_stopping.early_stop:print(Early stopping triggered)break