以下回答来自 ChatGPT

在深度学习中,Dataset 和 DataLoader 是两个常用的概念,用于将数据加载到模型中进行训练。Dataset 用于加载数据,并将数据转换为模型可以使用的形式;DataLoader 用于批量加载数据,并进行数据增强、shuffle 等操作。下面分别介绍一下它们的作用。

Dataset Dataset 是 PyTorch 中用于加载数据的类,它继承自 torch. utils. data. Dataset。 在自定义 Dataset 时,需要实现__len____getitem__方法。其中__len__方法用于返回数据集的长度,__getitem__方法用于返回数据集中指定索引的数据。 在自定义 Dataset 时,可以根据实际需求对数据进行预处理、转换格式等操作,以适应模型的输入要求。Dataset 的一个常见应用是对图像进行数据增强,包括随机裁剪、随机翻转、调整亮度对比度等操作。可以使用 torchvision. transforms 中的函数实现数据增强。

DataLoader DataLoader 是 PyTorch 中用于批量加载数据的类,它继承自 torch. utils. data. DataLoader。使用 DataLoader 可以将 Dataset 中的数据进行批量加载,并进行 shuffle、数据增强等操作。DataLoader 还可以设置多线程加载数据,以加快数据加载速度。 在使用 DataLoader 时,需要指定 batch_sizeshuffle 等参数。batch_size 指定每个 batch 中的样本数量,shuffle 指定是否随机打乱数据。可以使用 DataLoader 加载自定义的 Dataset,也可以使用 torchvision. datasets 中提供的常见数据集,如 MNIST、CIFAR 等。 通常情况下,我们使用 DataLoader 来将数据分成若干个 batch,然后将每个 batch 的数据输入模型中进行训练,以提高训练效率。

dataset 是数据集 dataloader 是数据加载

前言

batch_size

下面例子直观感受 batch_size 的作用:

1
2
3
4
5
6
7
list = [1,2,3,4,5,6,7]     # 所有数据,dataset  

batch_size = 2

for i in range(0,len(list),batch_size): # 数据加载过程,dataloader
batch_data = list[i:i+batch_size]
print(batch_data)

输出:
[1, 2]
[3, 4]
[5, 6]
[7]

epoch、shuffle

epoch 表示轮次,shuffle 表示是否打乱数据,下面给出例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import random  

list = [1,2,3,4,5,6,7] #所有数据,dataset

batch_size = 2
epoch = 3 # 轮次
shuffle = True # 是否打乱

for e in range(epoch):
if shuffle:
random.shuffle(list)
print(list)
for i in range(0,len(list),batch_size): #数据加载过程,dataloader
batch_data = list[i:i+batch_size]
print(batch_data)

输出:
[5, 7, 6, 4, 3, 1, 2]
[5, 7]
[6, 4]
[3, 1]
[2]
[7, 6, 3, 4, 5, 2, 1]
[7, 6]
[3, 4]
[5, 2]
[1]
[1, 7, 4, 6, 3, 5, 2]
[1, 7]
[4, 6]
[3, 5]
[2]

复现

DataSet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import random  

class MyDataset:
def __init__(self,all_data,batch_size,shuffle=True):
self.all_data = all_data
self.batch_size = batch_size
self.shuffle = shuffle

self.cursor = 0 # 左边界

# 魔术方法(特定条件下自动触发的函数)

def __iter__(self): # 只有第一次会触发,所以需要返回一个具有__next__的对象
print("======hello _iter_======")
if self.shuffle:
random.shuffle(self.all_data)
self.cursor = 0 # 游标重置
return self

def __next__(self):
if self.cursor >= len(self.all_data):
raise StopIteration
batch_data = self.all_data[self.cursor : self.cursor+self.batch_size]
self.cursor += self.batch_size
return batch_data

if __name__ == "__main__":
all_data = [1,2,3,4,5,6,7]
batch_size = 2
epoch = 3
shuffle = True

dataset = MyDataset(all_data,batch_size,shuffle)

for e in range(epoch):
# 把一个对象放在for上时,会自动调用这个对象的__iter__,但只会在第一个循环触发
for batch_data in dataset:
print(batch_data)

输出:
======hello _iter_======
[5, 1]
[6, 3]
[2, 7]
[4]
======hello _iter_======
[6, 3]
[5, 4]
[7, 2]
[1]
======hello _iter_======
[1, 2]
[6, 5]
[3, 7]
[4]

DataLoader

dataloader 主要用于加载数据,所以上面写的 dataset 的__next__的功能实际上时 dataloader 做的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import random  

class MyDataSet:
def __init__(self, all_data, batch_size, shuffle=True):
self.all_data = all_data
self.batch_size = batch_size
self.shuffle = shuffle

def __iter__(self): # 只有第一次会触发
print("======hello _iter_======")
if self.shuffle:
random.shuffle(self.all_data) # random在数据量很多时复杂度很高
return DataLoader(self)


class DataLoader:
def __init__(self, dataset):
self.dataset = dataset
self.cursor = 0 # 左边界

def __next__(self):
if self.cursor >= len(self.dataset.all_data):
raise StopIteration
batch_data = self.dataset.all_data[self.cursor: self.cursor + self.dataset.batch_size]
self.cursor += self.dataset.batch_size
return batch_data


if __name__ == "__main__":
all_data = [1, 2, 3, 4, 5, 6, 7]
batch_size = 2
epoch = 3
shuffle = True

dataset = MyDataSet(all_data, batch_size, shuffle)

for e in range(epoch):
for batch_data in dataset:
print(batch_data)

输出:
======hello _iter_======
[6, 4]
[3, 5]
[2, 7]
[1]
======hello _iter_======
[7, 1]
[5, 2]
[6, 3]
[4]
======hello _iter_======
[5, 6]
[4, 2]
[1, 7]
[3]

这里 shuffle 打乱数据是利用 random 进行的,在数据较少时可以,但是面对深度学习的超大数据时有着很高的复杂度。解决的思路是:不打乱数据,打乱索引。

优化

shuffle 利用打乱索引实现:在 22 行中先构建了一个索引 list,再利用 np. random. shuffle 对这个 list 进行打乱,循环时只需要每次再这个索引 list 中选取 batch_size 个数据,再在原来的 all_data 中找到对应数据即可。

batch_data = self. dataset. all_data[ind] 可以看出,在 python 中,下标可以是一个 list,返回的结果也是一个 list。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np  

class MyDataSet:
def __init__(self,all_data,batch_size,shuffle=True):
self.all_data = all_data
self.batch_size = batch_size
self.shuffle = shuffle

def __iter__(self): # 只有第一次会触发
print("======hello _iter_======")
return DataLoader(self)

def __len__(self):
return len(self.all_data)


class DataLoader:
def __init__(self,dataset):
self.dataset = dataset
self.cursor = 0 # 左边界

self.index = [i for i in range(len(self.dataset))]
if self.dataset.shuffle == True:
np.random.shuffle(self.index)

def __next__(self):
if self.cursor >= len(self.dataset.all_data):
raise StopIteration
ind = self.index[self.cursor : self.cursor+self.dataset.batch_size] # ind是个列表
batch_data = self.dataset.all_data[ind] # batch_data也是list
self.cursor += self.dataset.batch_size
return batch_data


if __name__ == "__main__":
all_data = np.array([1,2,3,4,5,6,7])
batch_size = 2
epoch = 3
shuffle = True

dataset = MyDataSet(all_data,batch_size,shuffle)

for e in range(epoch):
for batch_data in dataset:
print(batch_data)

输出:
======hello _iter_======
[3 1]
[5 7]
[4 2]
[6]
======hello _iter_======
[1 3]
[7 5]
[2 6]
[4]
======hello _iter_======
[2 4]
[1 6]
[7 5]
[3]

相关链接

bilibili
GitHub