How to split a dataset in PyTorch: torch.utils.data.Subset

Posted: 2019-11-04 12:25:12

torch.utils.data

For the full set of dataset utilities that PyTorch provides, see: https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler

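This module contains a number of classes and functions for working with datasets, among them Subset, SubsetRandomSampler and random_split.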
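Subset, the class named in the title, wraps an existing dataset together with a list of indices and exposes only those samples. Here is a minimal sketch. It uses a small made-up TensorDataset as a stand-in for real data.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset (made up for illustration): features 0..9, labels alternating 0/1
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10) % 2
dataset = TensorDataset(features, labels)

# Subset(dataset, indices) is a view restricted to the given indices; no data is copied
even_subset = Subset(dataset, [0, 2, 4, 6, 8])

print(len(even_subset))        # 5
x, y = even_subset[1]          # position 1 of the subset maps to dataset[2]
print(x.item(), y.item())      # 2.0 0
```

Splitting a dataset therefore comes down to choosing two disjoint index lists and wrapping each one in a Subset.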

Examples

The dataset-splitting methods that PyTorch provides are illustrated below with examples:

SubsetRandomSampler



import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

# MyCustomDataset and my_path stand for your own Dataset subclass and data location
dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PyTorch data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

# Usage example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:
    for batch_index, (faces, labels) in enumerate(train_loader):
        # ... forward pass, loss, backward pass and optimizer step go here ...
        pass

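The snippet above depends on a user-defined MyCustomDataset. Below is a self-contained sketch of the same SubsetRandomSampler approach. It makes two assumptions: a random TensorDataset of 100 samples stands in for the real data, and a local NumPy Generator replaces the global np.random.seed call.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Stand-in for MyCustomDataset (assumption): 100 samples, 3 random features, binary labels
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
batch_size = 16
validation_split = 0.2

# Shuffle the indices once with a seeded local generator, then cut them into two parts
indices = np.arange(len(dataset))
rng = np.random.default_rng(42)
rng.shuffle(indices)
split = int(np.floor(validation_split * len(dataset)))
train_indices, val_indices = indices[split:].tolist(), indices[:split].tolist()

train_loader = DataLoader(dataset, batch_size=batch_size,
                          sampler=SubsetRandomSampler(train_indices))
val_loader = DataLoader(dataset, batch_size=batch_size,
                        sampler=SubsetRandomSampler(val_indices))

# The two index lists are disjoint and together cover the whole dataset
assert set(train_indices).isdisjoint(val_indices)
assert len(train_indices) + len(val_indices) == len(dataset)

# 80 training samples / 16 per batch -> 5 batches; 20 validation samples -> 2 batches
print(len(train_loader), len(val_loader))  # 5 2
```

SubsetRandomSampler draws a fresh permutation of its indices every time the loader is iterated. The order of training samples therefore changes from epoch to epoch, while the membership of the train and validation sets stays fixed. Because the sampler already does the shuffling, shuffle=True must not also be passed to DataLoader; the two options are mutually exclusive.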
random_split

import torch

# full_dataset stands for your complete Dataset object
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

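random_split returns one Subset object per requested length. By default it draws from PyTorch's global random state. The sketch below passes a seeded torch.Generator through the generator argument instead, which makes the split reproducible; this argument exists in recent PyTorch versions. A toy TensorDataset of 50 samples stands in for full_dataset (an assumption).

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Stand-in for full_dataset (assumption): 50 samples with a single feature each
full_dataset = TensorDataset(torch.arange(50).float().unsqueeze(1))

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# A seeded generator makes the random split identical across runs
gen = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size],
                                           generator=gen)

# Each part is a Subset; .indices lists the samples it received
assert isinstance(train_dataset, Subset)
assert set(train_dataset.indices).isdisjoint(test_dataset.indices)
print(len(train_dataset), len(test_dataset))  # 40 10
```

Because the results are ordinary Subset objects, they can be passed straight to DataLoader. Unlike the SubsetRandomSampler approach, you can then use shuffle=True on the training loader.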
References:

https://www.cnblogs.com/marsggbo/p/10496696.html

https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets

https://likewind.top/2019/02/01/Pytorch-dataprocess/

https://blog.csdn.net/xholes/article/details/81410834

Original post: https://www.cnblogs.com/Bella2017/p/11791216.html
