pytorch日积月累5-pytorch数据读取机制

DataLoader和DataSet

  • torch.utils.data.DataLoader
  • 功能:构建可迭代的数据装载器
1
2
3
4
5
6
7
#构建可迭代的数据装载器
torch.utils.data.DataLoader( dataset,#Dataset类,决定数据从哪读取以及如何读取
batch_size=1,#批大小
shuffle=False,#每个epoch是否乱序
num_workers=0,#是否多进程读取数据
drop_last=False,#当样本数不能被batchsize整除时,是否舍弃最后一波数据
)
  • epoch: 所有训练样本都已输入到模型中,称为一个epoch
  • iteration:一批样本输入到模型中,称之为一个iteration
  • batchsize:批大小,决定一个epoch有多少个iteration

【举例】

  • 样本总数:80, Batchsize:8 ,则1 Epoch = 10 Iteration
  • 样本总数:87, Batchsize:8
    • 1 Epoch = 10 Iteration when drop_last = True
    • 1 Epoch = 11 Iteration when drop_last = False

torch.utils.data.Dataset

  • 功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
  • getitem :接收一个索引,返回一个样本
1
2
3
4
5
6
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError

def __add__(self, other):
return ConcatDataset([self, other])

数据读取:

  • 读哪些数据——sampler输出的index
  • 从哪里读数据——DataSet中的data_dir
  • 怎么读数据——Dataset中的getitem
1
2
train_data = RDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batach_size=BATCH_SIZE, shuffle=True)

数据读取的过程:

image-20200820203241733