网站推广软件有哪些,免费的舆情网站不需下载,如何优化关键词的排名,提供服务好的网站建设Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容#xff0c;它相当于一个类似列表的数据结构#xff0c;具有确定的长度#xff0c;能够用索引获取数据集中的元素。 而DataLoader定义了按batch加载数据集的方法#xff0c;它是…Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容它相当于一个类似列表的数据结构具有确定的长度能够用索引获取数据集中的元素。 而DataLoader定义了按batch加载数据集的方法它是一个实现了**iter**方法的可迭代对象每次迭代输出一个batch的数据。 DataLoader能够控制batch的大小batch中元素的采样方法以及将batch结果整理成模型所需输入形式的方法collate_fn并且能够使用多进程读取数据。 在绝大部分情况下用户只需实现Dataset的__len__方法和__getitem__方法就可以轻松构建自己的数据集并用默认数据管道进行加载。
一、深入理解Dataset和DataLoader的原理
1. 获取一个batch数据的步骤
让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。 (假定数据集的特征和标签分别表示为张量X和Y数据集可以表示为(X,Y), 假定batch大小为m) 1首先我们要确定数据集的长度n。 结果类似n 1000。 2然后我们从0到n-1的范围中抽样出m个数(batch大小)。 假定m4, 拿到的结果是一个列表类似indices [1,4,8,9] 3接着我们从数据集中去取这m个数对应下标的元素。 拿到的结果是一个元组列表类似samples [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])] 4最后我们将结果整理成两个张量作为输出。 拿到的结果是两个张量类似batch (features,labels) 其中 features torch.stack([X[1],X[4],X[8],X[9]]) labels torch.stack([Y[1],Y[4],Y[8],Y[9]])
2.Dataset和DataLoader的功能分工
上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。 第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader 的 sampler 和 batch_sampler参数指定的。 sampler参数指定单个元素抽样方法一般无需用户设置程序默认在DataLoader的参数shuffleTrue时采用随机抽样shuffleFalse时采用顺序抽样。 batch_sampler参数将多个抽样的元素整理成一个列表一般无需用户设置默认方法在DataLoader的参数drop_lastTrue时会丢弃数据集最后一个长度不能被batch大小整除的批次在drop_lastFalse时保留最后一个批次。 第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 getitem方法实现的。 第4个步骤的逻辑由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.utils.data import RandomSampler,BatchSampler ds TensorDataset(torch.randn(1000,3),torch.randint(low0,high2,size(1000,)).float())
dl DataLoader(ds,batch_size4,drop_last False)
features,labels next(iter(dl))
print(features ,features )
print(labels ,labels ) # step1: 确定数据集长度 (Dataset的 __len__ 方法实现)
ds TensorDataset(torch.randn(1000,3),torch.randint(low0,high2,size(1000,)).float())
print(n , len(ds)) # len(ds)等价于 ds.__len__()# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)
sampler RandomSampler(data_source ds)
batch_sampler BatchSampler(sampler sampler, batch_size 4, drop_last False)
for idxs in batch_sampler:indices idxsbreak
print(indices ,indices)# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)
batch [ds[i] for i in indices] # ds[i] 等价于 ds.__getitem__(i)
print(batch , batch)# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)
def collate_fn(batch):features torch.stack([sample[0] for sample in batch]) # torch.stack是一个torch库中的函数用于沿着指定的维度对输入的张量序列进行堆叠即堆叠张量labels torch.stack([sample[1] for sample in batch])return features,labels features,labels collate_fn(batch)
print(features ,features)
print(labels ,labels) 3.Dataset和DataLoader的核心源码
import torch
class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset, batch_size, collate_fn None, shuffle True, drop_last False):self.dataset datasetself.collate_fn collate_fnself.sampler torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler torch.utils.data.BatchSamplerself.sample_iter self.batch_sampler(self.sampler(self.dataset),batch_size batch_size,drop_last drop_last)self.collate_fn collate_fn if collate_fn is not None else \torch.utils.data._utils.collate.default_collatedef __next__(self):indices next(iter(self.sample_iter))batch self.collate_fn([self.dataset[i] for i in indices])return batchdef __iter__(self):return self
对源码进行测试
class ToyDataset(Dataset):def __init__(self,X,Y):self.X Xself.Y Y def __len__(self):return len(self.X)def __getitem__(self,index):return self.X[index],self.Y[index]X,Y torch.randn(1000,3),torch.randint(low0,high2,size(1000,)).float()
ds ToyDataset(X,Y)dl DataLoader(ds,batch_size4,drop_last False)
features,labels next(iter(dl))
print(features ,features )
print(labels ,labels ) 二、使用Dataset创建数据集
Dataset创建数据集常用的方法有
使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的arrayPandas的DataFrame需要先转换成Tensor)。使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。继承 torch.utils.data.Dataset 创建自定义数据集。
此外还可以通过
torch.utils.data.random_split 将一个数据集分割成多份常用于分割训练集验证集和测试集。调用Dataset的加法运算符()将多个数据集合并成一个数据集。
根据Tensor创建数据集
创建数据集
# 根据Tensor创建数据集from sklearn import datasets
iris datasets.load_iris()
ds_iris TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))# 分割成训练集和预测集
n_train int(len(ds_iris)*0.8)
n_val len(ds_iris) - n_train
ds_train,ds_val random_split(ds_iris,[n_train,n_val])print(type(ds_iris))
print(type(ds_train))加载数据集
# 使用DataLoader加载数据集
dl_train,dl_val DataLoader(ds_train,batch_size 8),DataLoader(ds_val,batch_size 8)for features,labels in dl_train:print(features,labels)break演示加法运算符的合并作用
# 演示加法运算符的合并作用ds_data ds_train ds_valprint(len(ds_train) ,len(ds_train))
print(len(ds_valid) ,len(ds_val))
print(len(ds_trainds_valid) ,len(ds_data))print(type(ds_data))根据图片目录创建图片数据集
先定义图片增强操作
# 定义图片增强操作transform_train transforms.Compose([transforms.RandomHorizontalFlip(), #随机水平翻转transforms.RandomVerticalFlip(), #随机垂直翻转transforms.RandomRotation(45), #随机在45度角度内旋转transforms.ToTensor() #转换成张量]
) transform_valid transforms.Compose([transforms.ToTensor()]
)根据图片目录创建数据集
# 根据图片目录创建数据集def transform_label(x):return torch.tensor([x]).float()ds_train datasets.ImageFolder(./eat_pytorch_datasets/cifar2/train/,transform transform_train,target_transform transform_label)
ds_val datasets.ImageFolder(./eat_pytorch_datasets/cifar2/test/,transform transform_valid,target_transform transform_label)print(ds_train.class_to_idx)# 使用DataLoader加载数据集dl_train DataLoader(ds_train,batch_size 50,shuffle True)
dl_val DataLoader(ds_val,batch_size 50,shuffle True)for features,labels in dl_train:print(features.shape)print(labels.shape)break创建自定义数据集
下面我们通过另外一种方式即继承 torch.utils.data.Dataset 创建自定义数据集的方式来对 cifar2构建 数据管道。
from pathlib import Path
from PIL import Image class Cifar2Dataset(Dataset): # 继承torch.utils.data.Datasetdef __init__(self,imgs_dir, img_transform):self.files list(Path(imgs_dir).rglob(*.jpg))self.transform img_transformdef __len__(self,):return len(self.files)def __getitem__(self,i):file_i str(self.files[i])img Image.open(file_i)tensor self.transform(img)label torch.tensor([1.0]) if 1_automobile in file_i else torch.tensor([0.0])return tensor,label train_dir ./eat_pytorch_datasets/cifar2/train/
test_dir ./eat_pytorch_datasets/cifar2/test/使用
# 定义图片增强
transform_train transforms.Compose([transforms.RandomHorizontalFlip(), #随机水平翻转transforms.RandomVerticalFlip(), #随机垂直翻转transforms.RandomRotation(45), #随机在45度角度内旋转transforms.ToTensor() #转换成张量]
) transform_val transforms.Compose([transforms.ToTensor()]
)ds_train Cifar2Dataset(train_dir,transform_train)
ds_val Cifar2Dataset(test_dir,transform_val)dl_train DataLoader(ds_train,batch_size 50,shuffle True)
dl_val DataLoader(ds_val,batch_size 50,shuffle True)for features,labels in dl_train:print(features.shape)print(labels.shape)break三、使用DataLoader加载数据集
DataLoader能够控制batch的大小batch中元素的采样方法随机否以及将batch结果整理成模型所需输入形式的方法collate_fn并且能够使用多进程读取数据。 DataLoader的函数签名如下。
DataLoader(dataset,batch_size1,shuffleFalse,samplerNone,batch_samplerNone,num_workers0,collate_fnNone,pin_memoryFalse,drop_lastFalse,timeout0,worker_init_fnNone,multiprocessing_contextNone,
)一般情况下我们仅仅会配置 dataset, batch_size, shuffle, num_workers, pin_memory, drop_last这六个参数 有时候对于一些复杂结构的数据集还需要自定义collate_fn函数其他参数一般使用默认值即可。 DataLoader除了可以加载我们前面讲的 torch.utils.data.Dataset 外还能够加载另外一种数据集 torch.utils.data.IterableDataset。 和Dataset数据集相当于一种列表结构不同IterableDataset相当于一种迭代器结构。 它更加复杂一般较少使用。
dataset : 数据集batch_size: 批次大小shuffle: 是否乱序sampler: 样本采样函数一般无需设置。batch_sampler: 批次采样函数一般无需设置。num_workers: 使用多进程读取数据设置的进程数。collate_fn: 整理一个批次数据的函数。pin_memory: 是否设置为锁业内存。默认为False锁业内存不会使用虚拟内存(硬盘)从锁业内存拷贝到GPU上速度会更快。drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。timeout: 加载一个数据批次的最长等待时间一般无需设置。worker_init_fn: 每个worker中dataset的初始化函数常用于 IterableDataset。一般不使用。
#构建输入数据管道
ds TensorDataset(torch.arange(1,50))
dl DataLoader(ds,batch_size 10,shuffle True,num_workers2,drop_last True)
#迭代数据
for batch, in dl:print(batch)参考https://github.com/lyhue1991/eat_pytorch_in_20_days