DataLoader 与 DataSet

PyTorch 数据读入是通过 Dataset+DataLoader 的方式完成的,Dataset 定义好数据的格式和数据变换形式,DataLoader 用 iterative 的方式不断读入批次数据。

torch.utils.data

PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。

数据模块又可以细分为 4 个部分:

  • 数据收集:样本和标签。
  • 数据划分:训练集、验证集和测试集
  • 数据读取:对应于 PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
  • 数据预处理:对应于 PyTorch 的 transforms

image-20220912180853467

Dataloader

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

1
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

https://gitee.com/geeks_z/upload_images/raw/master/202112212156941.png

参数

  • dataset (Dataset) – 加载数据的数据集,Dataset 类,决定数据从哪里读取以及如何读取。
  • batch_size (int, optional) – 每个 batch 加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个 epoch 重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
  • collate_fn (callable, optional) –
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的 batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个 batch 将更小。(默认: False)

DataLoader的使用方法示例:

1
2
3
4
5
6
7
8
9
from torch.utils.data import DataLoader

dataset = CifarDataset()
data_loader = DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)

#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
print(index,label,context)
print("*"*100)

数据迭代器的返回结果如下:

1
2
3
555 ('spam', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham', 'spam', 'ham', 'ham') ('URGENT! We are trying to contact U. Todays draw shows that you have won a £800 prize GUARANTEED. Call 09050003091 from....", 'swhrt how u dey,hope ur ok, tot about u 2day.love n miss.take care.')
***********************************************************************************
556 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') ('He telling not to tell any one. If so treat for me hi hi hi', 'Did u got that persons story', "Don kn....1000 cash prize or a prize worth £5000')

注意:

  1. len(dataset) = 数据集的样本数
  2. len(data_loader) = math.ceil(样本数/batch_size) 即向上取整

Dataset

torch.utils.data.Dataset 是 PyTorch 中的一个抽象类,用于表示一个数据集。当你想要创建自己的数据集时,你需要继承这个类并实现至少两个方法:__len____getitem__

参数说明

torch.utils.data.Dataset 本身并没有直接的参数,因为它是一个抽象基类,需要子类实现具体的方法。

需要实现的方法

  1. __len__(self):

    • 返回数据集的大小(即数据项的总数)。
    • 当你使用 len(dataset) 时,这个方法会被调用。
  2. __getitem__(self, index):

    • 根据提供的索引返回单个数据项。
    • 当你使用 dataset[index] 时,这个方法会被调用。
    • 返回的数据项通常是一个元组,包含输入数据和标签(如果有的话)。

示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torch.utils.data import Dataset

class MyDataset(Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx]

# 使用示例
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 获取数据集的大小
print(len(dataset)) # 输出: 5

# 获取索引为2的数据项
print(dataset[2]) # 输出: 3

其他常见方法(可选实现)

虽然 __len____getitem__ 是必需的,但你可能还想实现其他方法以提供更多的功能,例如:

  • __init__: 用于向类中传入外部参数,同时定义样本集
  • __add__(self, other): 实现数据集的加法操作,使得你可以合并两个数据集。
  • transform(self, fn): 定义一个转换方法,该方法接受一个函数 fn 并返回一个新的数据集,其中每个数据项都经过 fn 的处理。

注意事项

  • 当创建自定义数据集时,确保你的 __getitem__ 方法返回的数据类型与你的模型期望的输入类型相匹配。
  • 如果你想要进行批量处理,可以考虑使用 torch.utils.data.DataLoader,它可以与你的 Dataset 子类一起使用。

PyTorch 数据读取流程图

Untitled

首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调DatasetFetcher
根据index获取数据。在DatasetFetcher里会调用Dataset__getitem__()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser

通过 DataLoader 获取数据

1
for i, (_, inputs, targets) in enumerate(sample_loader):

其中 (_, inputs, targets)DummyDataset中的__getitem__返回值对应

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class DummyDataset(Dataset):
def __init__(self, images, labels, trsf, use_path=False):
……

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
if self.use_path:
image = self.trsf(pil_loader(self.images[idx]))
else:
image = self.trsf(Image.fromarray(self.images[idx]))
label = self.labels[idx]

return idx, image, label

torchvision

torchvision 是独立于 pytorch 的关于图像操作的一些方便工具库。

torchvision 的详细介绍在:https://pypi.org/project/torchvision/

torchvision 主要包括以下几个包:

  • vision.datasets : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的 Dataset 的子类
  • vision.models : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
  • vision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到 tensor ,numpy 数组到 tensor , tensor 到 图像等。
  • vision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个 mini-batch 的图像可以产生一个图像格网。

pytorch 自带的数据集

pytorch 中自带的数据集由两个上层 api 提供,分别是torchvisiontorchtext

其中:

  1. torchvision提供了对图片数据处理相关的 api 和数据
    • 数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
  2. torchtext提供了对文本数据处理相关的 API 和数据
    • 数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

下面我们以 Mnist 手写数字为例,来看看 pytorch 如何加载其中自带的数据集

使用方法和之前一样:

  1. 准备好 Dataset 实例
  2. 把 dataset 交给 dataloder 打乱顺序,组成 batch

torchversion.datasets

torchversoin.datasets中的数据集类(比如torchvision.datasets.MNIST),都是继承自Dataset

意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例

但是 MNIST API 中的参数需要注意一下:

torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)

  1. root参数表示数据存放的位置
  2. train:bool 类型,表示是使用训练集的数据还是测试集的数据
  3. download:bool 类型,表示是否需要下载数据到 root 目录
  4. transform:实现的对图片的处理函数

MNIST 数据集的介绍

数据集的原始地址:http://yann.lecun.com/exdb/mnist/

MNIST 是由Yann LeCun等人提供的免费的图像识别的数据集,其中包括 60000 个训练样本和 10000 个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28

执行代码,下载数据,观察数据类型:

1
2
3
import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])

下载的数据如下:

Untitled

代码输出结果如下:

1
2
3
4
5
6
7
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
(<PIL.Image.Image image mode=L size=28x28 at 0x18D303B9C18>, tensor(5))

可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值

返回值的第 0 个为 Image 类型,可以调用 show() 方法打开,发现为手写数字 5

1
2
3
4
5
import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])
img = dataset[0][0]
img.show() #打开图片

图片如下:

由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到。