深圳设计网站源码,网站定制开发加公众号,网站建设管理教程视频教程,蓝色phpcms律师网站模板phpcms律师文章目录 1. 导入相关库2. 加载数据集3. 整理数据集4. 图像增广5. 读取数据6. 微调预训练模型7. 定义损失函数和评价损失函数9. 训练模型 1. 导入相关库
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l2. 加载数据集
- 该数据… 文章目录 1. 导入相关库2. 加载数据集3. 整理数据集4. 图像增广5. 读取数据6. 微调预训练模型7. 定义损失函数和评价损失函数9. 训练模型 1. 导入相关库
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l2. 加载数据集
- 该数据集是完整数据集的小规模样本# 下载数据集
d2l.DATA_HUB[dog_tiny] (d2l.DATA_URL kaggle_dog_tiny.zip,0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d)# 如果使用Kaggle比赛的完整数据集请将下面的变量更改为False
demo True
if demo:data_dir d2l.download_extract(dog_tiny)
else:data_dir os.path.join(.., data, dog-breed-identification)3. 整理数据集
def reorg_dog_data(data_dir, valid_ratio):labels d2l.read_csv_labels(os.path.join(data_dir, labels.csv))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size 32 if demo else 128
valid_ratio 0.1
reorg_dog_data(data_dir, valid_ratio)4. 图像增广
transform_train torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale(0.08, 1.0), ratio(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])5. 读取数据
train_ds, train_valid_ds [torchvision.datasets.ImageFolder(os.path.join(data_dir, train_valid_test, folder),transformtransform_train) for folder in [train, train_valid]
]
valid_ds, test_ds [torchvision.datasets.ImageFolder(os.path.join(data_dir, train_valid_test, folder),transformtransform_test) for folder in [valid, test]
]train_iter, train_valid_iter [torch.utils.data.DataLoader(dataset, batch_size, shuffleTrue, drop_lastTrue) for dataset in (train_ds, train_valid_ds)
]
valid_iter torch.utils.data.DataLoader(valid_ds, batch_size, shuffleFalse, drop_lastTrue
)
test_iter torch.utils.data.DataLoader(test_ds, batch_size, shuffleFalse, drop_lastTrue
)6. 微调预训练模型
def get_net(devices):finetune_net nn.Sequential()finetune_net.features torchvision.models.resnet34(weightstorchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络共有120个输出类别finetune_net.output_new nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad Falsereturn finetune_net# 查看网络模型
get_net(devicesd2l.try_all_gpus())7. 定义损失函数和评价损失函数
# 定义损失函数
loss nn.CrossEntropyLoss(reductionnone)def evaluate_loss(data_iter, net, device):l_sum, n 0.0, 0for features, labels in data_iter:features, labels features.to(device[0]), labels.to(device[0])outputs net(features)l loss(outputs, labels)l_sum l.sum()n labels.numel()return (l_sum / n).to(cpu)定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net nn.DataParallel(net, device_idsdevices).to(devices[0])trainer torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lrlr, momentum0.9, weight_decaywd)scheduler torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer len(train_iter), d2l.Timer()legend [train loss]if valid_iter is not None:legend.append(valid loss)animator d2l.Animator(xlabelepoch, xlim[1, num_epochs], legendlegend)for epoch in range(num_epochs):metric d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output net(features)l loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i 1) % (num_batches // 5) 0 or i num_batches - 1:animator.add(epoch (i 1) / num_batches, (metric[0] / metric[1], None))measures ftrain loss {metric[0] / metric[1]:.3f}if valid_iter is not None :valid_loss evaluate_loss(valid_iter, net, devices)animator.add(epoch 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures f, valid loss {valid_loss:.3f}print(measures f\n{metric[1] * num_epochs / timer.sum():.1f}fexamples/sec on {str(devices)})9. 训练模型
devices, num_epochs, lr, wd d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, 2, 0.9, get_net(devices)import time# 在开头设置开始时间
start time.perf_counter() # start time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end time.perf_counter() # end time.clock() python3.8之前可以# 再将其进行打印即可显示出程序完成的运行耗时
print(f运行耗时{(end-start):.4f})