山西省住房和城乡建设部网站,wordpress 数据库同步,新网网站建设,天河做网站开发最近需要训练一个有200多类的图片分类网络#xff0c;搜了一遍#xff0c;发现居然没有很合适用的开源项目#xff0c;于是自己简单撸了一个轮子#xff0c;项目地址: https://github.com/xuduo35/imgcls_pytorch。支持如下backbone:
alexnetresnet18,resnet34,resnet50,r… 最近需要训练一个有200多类的图片分类网络搜了一遍发现居然没有很合适用的开源项目于是自己简单撸了一个轮子项目地址: https://github.com/xuduo35/imgcls_pytorch。支持如下backbone:
alexnetresnet18,resnet34,resnet50,resnet101, resnet152, resnext101_32x4d, resnext101_64x4dvgg11_bn, vgg16_bndensenet121, densenet169, densenet161inceptionv3, inceptionv4, inceptionresnetv2, bninceptionxception, xception_attdpn98, dpn107, dpn131senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4dpnasnet5largepolynetefficientnet 使用简便第一步是按如下格式准备数据集
your_dataset_directory class1 1.jpg2.jpgclass2 1.jpg2.jpg...... 自定义一个Dataset实现如下
class ImageFolderEx(Dataset):def __init__(self, image_dir, image_files, image_labels, classnum1000, transformNone):self.image_dir image_dirself.image_files image_filesself.image_labels image_labelsself.classnum classnumself.transform transformdef __len__(self):return len(self.image_files)def __getitem__(self, index):image_name os.path.join(self.image_dir, self.image_files[index]) image cv2.imread(image_name)image image[:,:,::-1]image Image.fromarray(image)label self.image_labels[index]if self.transform:image self.transform(image)onehot [0]*self.classnumonehot[label] 1return (image, np.array(onehot).astype(np.float32)) 支持简单的余弦退火学习率调度器 scheduler optim.lr_scheduler.CosineAnnealingLR(optimizer,T_maxargs.epochs, eta_min0.00001, last_epoch-1) 一方面执行简单的数据增广 transforms.Compose([transforms.RandomRotation([-13,13]),transforms.ColorJitter(brightness0.5, contrast0.5, saturation0.5),transforms.Resize(args.imgsz32),transforms.RandomCrop(args.imgsz),transforms.ToTensor(),normalize]) 另外再按照一定比例执行cutmix和mixup增广 cutmix v2.CutMix(num_classeslen(classes))mixup v2.MixUp(num_classeslen(classes))cutmix_or_mixup v2.RandomChoice([cutmix, mixup]) 训练命令example CUDA_VISIBLE_DEVICES0 python3 -u train.py --backbone resnet101 --classnum 270 --workers 32 --lr0.001 --epochs 30 --train_bs 160 --datadir your_dataset_directory 简单的基于gradio的demo app.py模型路径要简单调整一下 CUDA_VISIBLE_DEVICES0 PORT8000 python3 -u app.py 如果训练过程需要tensorboard显示中文可以执行如下python指令然后执行输出内容里的命令 python3 fixfont.py /home/ubuntu/torch19/lib/python3.10/site-packages/matplotlib/mpl-data/matplotlibrc /home/ubuntu/.cache/matplotlib wget https://github.com/StellarCN/scp_zh/raw/master/fonts/SimHei.ttf rm -f /home/ubuntu/.cache/matplotlib/* cp ./SimHei.ttf /home/ubuntu/torch19/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf