以下回答来自 ChatGPT
在深度学习中,Dataset 和 DataLoader 是两个常用的概念,用于将数据加载到模型中进行训练。Dataset 用于加载数据,并将数据转换为模型可以使用的形式;DataLoader 用于批量加载数据,并进行数据增强、shuffle 等操作。下面分别介绍一下它们的作用。
Dataset Dataset 是 PyTorch 中用于加载数据的类,它继承自 torch. utils. data. Dataset。 在自定义 Dataset 时,需要实现__len__和__getitem__方法。其中__len__方法用于返回数据集的长度,__getitem__方法用于返回数据集中指定索引的数据。 在自定义 Dataset 时,可以根据实际需求对数据进行预处理、转换格式等操作,以适应模型的输入要求。Dataset 的一个常见应用是对图像进行数据增强,包括随机裁剪、随机翻转、调整亮度对比度等操作。可以使用 torchvision. transforms 中的函数实现数据增强。
DataLoader DataLoader 是 PyTorch 中用于批量加载数据的类,它继承自 torch. utils. data. DataLoader。使用 DataLoader 可以将 Dataset 中的数据进行批量加载,并进行 shuffle、数据增强等操作。DataLoader 还可以设置多线程加载数据,以加快数据加载速度。 在使用 DataLoader 时,需要指定 batch_size、shuffle 等参数。batch_size 指定每个 batch 中的样本数量,shuffle 指定是否随机打乱数据。可以使用 DataLoader 加载自定义的 Dataset,也可以使用 torchvision. datasets 中提供的常见数据集,如 MNIST、CIFAR 等。 通常情况下,我们使用 DataLoader 来将数据分成若干个 batch,然后将每个 batch 的数据输入模型中进行训练,以提高训练效率。
dataset 是数据集 dataloader 是数据加载
前言
batch_size
下面例子直观感受 batch_size 的作用:
1 2 3 4 5 6 7 list = [1 ,2 ,3 ,4 ,5 ,6 ,7 ] # 所有数据,dataset batch_size = 2 for i in range (0 ,len(list) ,batch_size): # 数据加载过程,dataloader batch_data = list[i:i+batch_size] print(batch_data)
输出:
[1, 2]
[3, 4]
[5, 6]
[7]
epoch、shuffle
epoch 表示轮次,shuffle 表示是否打乱数据,下面给出例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import random list = [1 ,2 ,3 ,4 ,5 ,6 ,7 ] #所有数据,dataset batch_size = 2 epoch = 3 # 轮次 shuffle = True # 是否打乱 for e in range (epoch) : if shuffle: random.shuffle(list) print(list) for i in range (0 ,len(list) ,batch_size): #数据加载过程,dataloader batch_data = list[i:i+batch_size] print(batch_data)
输出:
[5, 7, 6, 4, 3, 1, 2]
[5, 7]
[6, 4]
[3, 1]
[2]
[7, 6, 3, 4, 5, 2, 1]
[7, 6]
[3, 4]
[5, 2]
[1]
[1, 7, 4, 6, 3, 5, 2]
[1, 7]
[4, 6]
[3, 5]
[2]
复现
DataSet
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 import random class MyDataset : def __init__ (self,all_data,batch_size,shuffle=True) : self.all_data = all_data self.batch_size = batch_size self.shuffle = shuffle self.cursor = 0 # 左边界 # 魔术方法(特定条件下自动触发的函数) def __iter__ (self) : # 只有第一次会触发,所以需要返回一个具有__next__的对象 print("======hello _iter_======" ) if self.shuffle: random.shuffle(self.all_data) self.cursor = 0 # 游标重置 return self def __next__ (self) : if self.cursor >= len(self.all_data): raise StopIteration batch_data = self.all_data[self.cursor : self.cursor+self.batch_size] self.cursor += self.batch_size return batch_data if __name__ == "__main__" : all_data = [1 ,2 ,3 ,4 ,5 ,6 ,7 ] batch_size = 2 epoch = 3 shuffle = True dataset = MyDataset(all_data,batch_size,shuffle) for e in range (epoch) : # 把一个对象放在for 上时,会自动调用这个对象的__iter__,但只会在第一个循环触发 for batch_data in dataset: print(batch_data)
输出:
======hello _iter_======
[5, 1]
[6, 3]
[2, 7]
[4]
======hello _iter_======
[6, 3]
[5, 4]
[7, 2]
[1]
======hello _iter_======
[1, 2]
[6, 5]
[3, 7]
[4]
DataLoader
dataloader 主要用于加载数据,所以上面写的 dataset 的__next__的功能实际上时 dataloader 做的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 import random class MyDataSet : def __init__ (self, all_data, batch_size, shuffle=True) : self.all_data = all_data self.batch_size = batch_size self.shuffle = shuffle def __iter__ (self) : # 只有第一次会触发 print("======hello _iter_======" ) if self.shuffle: random.shuffle(self.all_data) # random在数据量很多时复杂度很高 return DataLoader(self) class DataLoader : def __init__ (self, dataset) : self.dataset = dataset self.cursor = 0 # 左边界 def __next__ (self) : if self.cursor >= len(self.dataset.all_data): raise StopIteration batch_data = self.dataset.all_data[self.cursor: self.cursor + self.dataset.batch_size] self.cursor += self.dataset.batch_size return batch_data if __name__ == "__main__" : all_data = [1 , 2 , 3 , 4 , 5 , 6 , 7 ] batch_size = 2 epoch = 3 shuffle = True dataset = MyDataSet(all_data, batch_size, shuffle) for e in range (epoch) : for batch_data in dataset: print(batch_data)
输出:
======hello _iter_======
[6, 4]
[3, 5]
[2, 7]
[1]
======hello _iter_======
[7, 1]
[5, 2]
[6, 3]
[4]
======hello _iter_======
[5, 6]
[4, 2]
[1, 7]
[3]
这里 shuffle 打乱数据是利用 random 进行的,在数据较少时可以,但是面对深度学习的超大数据时有着很高的复杂度。解决的思路是:不打乱数据,打乱索引。
优化
shuffle 利用打乱索引实现:在 22 行中先构建了一个索引 list,再利用 np. random. shuffle 对这个 list 进行打乱,循环时只需要每次再这个索引 list 中选取 batch_size 个数据,再在原来的 all_data 中找到对应数据即可。
batch_data = self. dataset. all_data[ind] 可以看出,在 python 中,下标可以是一个 list,返回的结果也是一个 list。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 import numpy as np class MyDataSet : def __init__ (self,all_data,batch_size,shuffle=True) : self.all_data = all_data self.batch_size = batch_size self.shuffle = shuffle def __iter__ (self) : # 只有第一次会触发 print("======hello _iter_======" ) return DataLoader(self) def __len__ (self) : return len(self.all_data) class DataLoader : def __init__ (self,dataset) : self.dataset = dataset self.cursor = 0 # 左边界 self.index = [i for i in range (len(self.dataset) )] if self.dataset.shuffle == True: np.random.shuffle(self.index) def __next__ (self) : if self.cursor >= len(self.dataset.all_data): raise StopIteration ind = self.index[self.cursor : self.cursor+self.dataset.batch_size] # ind是个列表 batch_data = self.dataset.all_data[ind] # batch_data也是list self.cursor += self.dataset.batch_size return batch_data if __name__ == "__main__" : all_data = np.array([1 ,2 ,3 ,4 ,5 ,6 ,7 ]) batch_size = 2 epoch = 3 shuffle = True dataset = MyDataSet(all_data,batch_size,shuffle) for e in range (epoch) : for batch_data in dataset: print(batch_data)
输出:
======hello _iter_======
[3 1]
[5 7]
[4 2]
[6]
======hello _iter_======
[1 3]
[7 5]
[2 6]
[4]
======hello _iter_======
[2 4]
[1 6]
[7 5]
[3]
相关链接
bilibili
GitHub