shuffle = False 时,不打乱数据顺序; shuffle = True 时,随机打乱数据顺序。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import numpy as np import h5py import torch from torch.utils.data import DataLoader, Dataset h5f = h5py.File('train.h5', 'w'); data1 = np.array([[1,2,3], [2,5,6], [3,5,6], [4,5,6]]) data2 = np.array([[1,1,1], [1,2,6], [1,3,6], [1,4,6]]) h5f.create_dataset(str('data'), data=data1) h5f.create_dataset(str('label'), data=data2) class Dataset(Dataset): def __init__(self): h5f = h5py.File('train.h5', 'r') self.data = h5f['data'] self.label = h5f['label'] def __getitem__(self, index): data = torch.from_numpy(self.data[index]) label = torch.from_numpy(self.label[index]) return data, label def __len__(self): assert self.data.shape[0] == self.label.shape[0], "wrong data length" return self.data.shape[0] dataset_train = Dataset() loader_train = DataLoader(dataset=dataset_train, batch_size=2, shuffle = True) for i, data in enumerate(loader_train): train_data, label = data print(train_data) |