当前位置: 首页 > news >正文

深圳设计网站源码网站定制开发加公众号

深圳设计网站源码,网站定制开发加公众号,网站建设管理教程视频教程,蓝色phpcms律师网站模板phpcms律师文章目录 1. 导入相关库2. 加载数据集3. 整理数据集4. 图像增广5. 读取数据6. 微调预训练模型7. 定义损失函数和评价损失函数9. 训练模型 1. 导入相关库 import os import torch import torchvision from torch import nn from d2l import torch as d2l2. 加载数据集 - 该数据… 文章目录 1. 导入相关库2. 加载数据集3. 整理数据集4. 图像增广5. 读取数据6. 微调预训练模型7. 定义损失函数和评价损失函数9. 训练模型 1. 导入相关库 import os import torch import torchvision from torch import nn from d2l import torch as d2l2. 加载数据集 - 该数据集是完整数据集的小规模样本# 下载数据集 d2l.DATA_HUB[dog_tiny] (d2l.DATA_URL kaggle_dog_tiny.zip,0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d)# 如果使用Kaggle比赛的完整数据集请将下面的变量更改为False demo True if demo:data_dir d2l.download_extract(dog_tiny) else:data_dir os.path.join(.., data, dog-breed-identification)3. 整理数据集 def reorg_dog_data(data_dir, valid_ratio):labels d2l.read_csv_labels(os.path.join(data_dir, labels.csv))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size 32 if demo else 128 valid_ratio 0.1 reorg_dog_data(data_dir, valid_ratio)4. 图像增广 transform_train torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale(0.08, 1.0), ratio(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) transform_test torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])5. 读取数据 train_ds, train_valid_ds [torchvision.datasets.ImageFolder(os.path.join(data_dir, train_valid_test, folder),transformtransform_train) for folder in [train, train_valid] ] valid_ds, test_ds [torchvision.datasets.ImageFolder(os.path.join(data_dir, train_valid_test, folder),transformtransform_test) for folder in [valid, test] ]train_iter, train_valid_iter [torch.utils.data.DataLoader(dataset, batch_size, shuffleTrue, drop_lastTrue) for dataset in (train_ds, train_valid_ds) ] valid_iter torch.utils.data.DataLoader(valid_ds, batch_size, shuffleFalse, drop_lastTrue ) test_iter torch.utils.data.DataLoader(test_ds, batch_size, shuffleFalse, drop_lastTrue )6. 微调预训练模型 def get_net(devices):finetune_net nn.Sequential()finetune_net.features torchvision.models.resnet34(weightstorchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络共有120个输出类别finetune_net.output_new nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad Falsereturn finetune_net# 查看网络模型 get_net(devicesd2l.try_all_gpus())7. 定义损失函数和评价损失函数 # 定义损失函数 loss nn.CrossEntropyLoss(reductionnone)def evaluate_loss(data_iter, net, device):l_sum, n 0.0, 0for features, labels in data_iter:features, labels features.to(device[0]), labels.to(device[0])outputs net(features)l loss(outputs, labels)l_sum l.sum()n labels.numel()return (l_sum / n).to(cpu)定义训练函数 def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net nn.DataParallel(net, device_idsdevices).to(devices[0])trainer torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lrlr, momentum0.9, weight_decaywd)scheduler torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer len(train_iter), d2l.Timer()legend [train loss]if valid_iter is not None:legend.append(valid loss)animator d2l.Animator(xlabelepoch, xlim[1, num_epochs], legendlegend)for epoch in range(num_epochs):metric d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output net(features)l loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i 1) % (num_batches // 5) 0 or i num_batches - 1:animator.add(epoch (i 1) / num_batches, (metric[0] / metric[1], None))measures ftrain loss {metric[0] / metric[1]:.3f}if valid_iter is not None :valid_loss evaluate_loss(valid_iter, net, devices)animator.add(epoch 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures f, valid loss {valid_loss:.3f}print(measures f\n{metric[1] * num_epochs / timer.sum():.1f}fexamples/sec on {str(devices)})9. 训练模型 devices, num_epochs, lr, wd d2l.try_all_gpus(), 10, 1e-4, 1e-4 lr_period, lr_decay, net, 2, 0.9, get_net(devices)import time# 在开头设置开始时间 start time.perf_counter() # start time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间 end time.perf_counter() # end time.clock() python3.8之前可以# 再将其进行打印即可显示出程序完成的运行耗时 print(f运行耗时{(end-start):.4f})
http://www.dnsts.com.cn/news/129627.html

相关文章:

  • 网站会员系统制作wordpress 手机号登入
  • 依波手表价格 官方网站私密浏览器怎么看片
  • 做网站龙岗如何做网址
  • 汽车网站制作传媒公司如何注册
  • 网站二级栏目个人网站包含哪些内容
  • 茂名市住房和城乡建设局网站绿色wordpress主题模板
  • 南宁网站开发建设asp.net mvc 网站开发之美 pdf
  • iis 网站关闭网站html地图怎么做的
  • 实验室网站制作购物网站建设个人总结
  • c2c网站免费建设自己做网站赚钱
  • 做ps网页设计的网站有哪些网站关键词抓取
  • 想找工作去哪个网站京津冀协同发展英语
  • 自己做网站服务器的备案方法企业推广图片
  • 网站建设培训要多久西安网站建设淘猫网络
  • 者珠海市建设局网站工作网站开发制作
  • 手机怎么建立自己的网站石家庄网站建设推广服务
  • pc网站生成手机网站wordpress 宅谈
  • 提供秦皇岛网站建设价格黄页app下载
  • vs2010网站开发实例龙华网站建设推广平台
  • 自己做的网站抬头在哪里改北京如何做网站
  • 企业网站建设合同书盖章页设计师工作室网站
  • 鄂尔多斯网站制作 建设网站建设详细合同范本
  • 网站编程沧州大型网站建设
  • 青岛模板建站多少钱百度如何发布作品
  • 一下成都网站建设公司seo网站查询工具
  • php怎么做全网小视频网站网站建设合同报价
  • 农安县住房和城乡建设厅网站台州建设企业网站
  • php网站开发师wordpress不能编辑文章
  • 广州建网站技术网站建设教程网哪个好
  • 如今做啥网站能致富湖北省建设工程招标网站