https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
https://blog.csdn.net/l8947943/article/details/103733473
1. 我们需要加载自己的数据集,使用Dataset和DataLoader
Dataset :是被封装进DataLoader里,实现该方法封装自己的数据和标签。DataLoader :被封装入DataLoader迭代器里,实现该方法达到数据的划分。
2.Dataset
主要继承该方法必须实现两个方法:
_getitem_() _len_()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import torch import numpy as np # 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法 class GetLoader(torch.utils.data.Dataset): # 初始化函数,得到数据 def __init__(self, data_root, data_label): self.data = data_root self.label = data_label # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回 def __getitem__(self, index): data = self.data[index] labels = self.label[index] return data, labels # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼 def __len__(self): return len(self.data) # 随机生成数据,大小为10 * 20列 source_data = np.random.rand(10, 20) # 随机生成标签,大小为10 * 1列 source_label = np.random.randint(0,2,(10, 1)) # 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels torch_data = GetLoader(source_data, source_label) |
3.DataLoader
提供对
1 | torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers) |
数含义如下:
- dataset: 加载torch.utils.data.Dataset对象数据
- batch_size: 每个batch的大小
- shuffle:是否对数据进行打乱
- drop_last:是否对无法整除的最后一个datasize进行丢弃
- um_workers:表示加载的时候子进程数,一般GPU使用
因此,在实现过程中我们测试如下(紧跟上述用例):
1 2 3 4 | from torch.utils.data import DataLoader # 读取数据 datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2) |
此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。
4.查看数据
我们可以通过迭代器
1 2 3 | for i, data in enumerate(datas): # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels print("第 {} 个Batch \n{}".format(i, data)) |
5.使用自己保存的“npy”数据集进行加载
定义一个继承dataset的类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import numpy as np from torch.utils.data.dataset import Dataset import torch # 定义CustomDataset类,继承Dataset方法,并重写__getitem__()和__len__()方法 class CustomDataset(torch.utils.data.Dataset): # 初始化函数,得到数据 def __init__(self, pathData, pathLabel): self.data = np.load(pathData) # 传入了dataset X的路径,并使用np.load进行加载数据 self.label = np.load(pathLabel) # 传入了label Y的路径 # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回 def __getitem__(self, index): data = self.data[index] labels = self.label[index] return data, labels # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼 def __len__(self): return len(self.data) |
加载数据
我自己的数据集格式:

1 2 3 4 5 6 7 8 9 10 11 | from torch.utils.data import DataLoader from CustomDataset import CustomDataset pathX = './datasetXPro.npy' pathY = './datasetYPro.npy' torch_data = CustomDataset(pathX, pathY) # 读取数据 datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2) for i, data in enumerate(datas): # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels print("第 {} 个Batch \n{}".format(i, data)) |