torch.utils.data
TensorDataset
简介
顾名思义,torch.utils.data 中的 TensorDataset 基于一系列张量构建数据集。这些张量的形状可以不尽相同,但第一个维度必须具有相同大小,这是为了保证在使用 DataLoader 时可以正常地返回一个批量的数据。
源码解读
以下是 TensorDataset 的源码:
1 | |
*tensors 告诉我们实例化 TensorDataset 时传入的是一系列张量,即:
1 | |
随后的 assert 是用来确保传入的这些张量中,每个张量在第一个维度的大小都等于第一个张量在第一个维度的大小,即要求所有张量在第一个维度的大小都相同。
__getitem__ 方法返回的结果等价于
1 | |
从这行代码可以看出,如果$n$ 张量在第一个维度的大小不完全相同,则必然会有一个张量出现 IndexError。确保第一个维度大小相同也是为了之后传入 DataLoader 中能够正常地以一个批量的形式加载。
__len__ 就不用多说了,因为所有张量的第一个维度大小都相同,所以直接返回传入的第一个张量在第一个维度的大小即可。
📌
TensorDataset将张量的第一个维度视为数据集大小的维度,数据集在传入 DataLoader 后,该维度也是 batch_size 所在的维度
通过例子进一步理解
假设当前目录下存放一个 data.csv 文件,其中的每一行的后六个数字代表样本对应的特征向量,第一个数字代表该样本对应的标签。
1 | |
接下来我们分别用普通方法和 TensorDataset 方法来构建数据集。
普通方法:
1 | |
TensorDataset 方法
1 | |
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Hongwei Zhao's Blog!



