DataLoader 与 DataSet
DataLoader 与 DataSet
PyTorch 数据读入是通过 Dataset+DataLoader 的方式完成的,Dataset 定义好数据的格式和数据变换形式,DataLoader 用 iterative 的方式不断读入批次数据。
PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。
数据模块又可以细分为 4 个部分:
- 数据收集:样本和标签。
- 数据划分:训练集、验证集和测试集
- 数据读取:对应于 PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
- 数据预处理:对应于 PyTorch 的 transforms
Dataloader
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
1 | |
参数
- dataset (Dataset) – 加载数据的数据集,Dataset 类,决定数据从哪里读取以及如何读取。
- batch_size (int, optional) – 每个 batch 加载多少个样本(默认: 1)。
- shuffle (bool, optional) – 设置为
True时会在每个 epoch 重新打乱数据(默认: False). - sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略
shuffle参数。 - num_workers (int, optional) – 用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
- collate_fn (callable, optional) –
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的 batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个 batch 将更小。(默认: False)
DataLoader的使用方法示例:
1 | |
数据迭代器的返回结果如下:
1 | |
注意:
len(dataset) = 数据集的样本数len(data_loader) = math.ceil(样本数/batch_size) 即向上取整
Dataset
torch.utils.data.Dataset是 PyTorch 中的一个抽象类,用于表示一个数据集。当你想要创建自己的数据集时,你需要继承这个类并实现至少两个方法:__len__和__getitem__。
参数说明
torch.utils.data.Dataset 本身并没有直接的参数,因为它是一个抽象基类,需要子类实现具体的方法。
需要实现的方法
__len__(self):- 返回数据集的大小(即数据项的总数)。
- 当你使用
len(dataset)时,这个方法会被调用。
__getitem__(self, index):- 根据提供的索引返回单个数据项。
- 当你使用
dataset[index]时,这个方法会被调用。 - 返回的数据项通常是一个元组,包含输入数据和标签(如果有的话)。
示例
1 | |
其他常见方法(可选实现)
虽然 __len__ 和 __getitem__ 是必需的,但你可能还想实现其他方法以提供更多的功能,例如:
__init__: 用于向类中传入外部参数,同时定义样本集__add__(self, other): 实现数据集的加法操作,使得你可以合并两个数据集。transform(self, fn): 定义一个转换方法,该方法接受一个函数fn并返回一个新的数据集,其中每个数据项都经过fn的处理。
注意事项
- 当创建自定义数据集时,确保你的
__getitem__方法返回的数据类型与你的模型期望的输入类型相匹配。 - 如果你想要进行批量处理,可以考虑使用
torch.utils.data.DataLoader,它可以与你的Dataset子类一起使用。
PyTorch 数据读取流程图
首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调DatasetFetcher
根据index获取数据。在DatasetFetcher里会调用Dataset的__getitem__()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser。
通过 DataLoader 获取数据
1 | |
其中 (_, inputs, targets)与DummyDataset中的__getitem__返回值对应
1 | |
torchvision
torchvision 是独立于 pytorch 的关于图像操作的一些方便工具库。
torchvision 的详细介绍在:https://pypi.org/project/torchvision/
torchvision 主要包括以下几个包:
- vision.datasets : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的 Dataset 的子类
- vision.models : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
- vision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到 tensor ,numpy 数组到 tensor , tensor 到 图像等。
- vision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个 mini-batch 的图像可以产生一个图像格网。
pytorch 自带的数据集
pytorch 中自带的数据集由两个上层 api 提供,分别是torchvision和torchtext
其中:
torchvision提供了对图片数据处理相关的 api 和数据- 数据位置:
torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
- 数据位置:
torchtext提供了对文本数据处理相关的 API 和数据- 数据位置:
torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)
- 数据位置:
下面我们以 Mnist 手写数字为例,来看看 pytorch 如何加载其中自带的数据集
使用方法和之前一样:
- 准备好 Dataset 实例
- 把 dataset 交给 dataloder 打乱顺序,组成 batch
torchversion.datasets
torchversoin.datasets中的数据集类(比如torchvision.datasets.MNIST),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例
但是 MNIST API 中的参数需要注意一下:
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)
root参数表示数据存放的位置train:bool 类型,表示是使用训练集的数据还是测试集的数据download:bool 类型,表示是否需要下载数据到 root 目录transform:实现的对图片的处理函数
MNIST 数据集的介绍
数据集的原始地址:http://yann.lecun.com/exdb/mnist/
MNIST 是由Yann LeCun等人提供的免费的图像识别的数据集,其中包括 60000 个训练样本和 10000 个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28
执行代码,下载数据,观察数据类型:
1 | |
下载的数据如下:
代码输出结果如下:
1 | |
可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值
返回值的第 0 个为 Image 类型,可以调用 show() 方法打开,发现为手写数字 5
1 | |
图片如下:
由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到。







