lee-romantic 's Blog
Everything is OK!
Toggle navigation
lee-romantic 's Blog
主页
About Me
归档
标签
pytorch学习:自定义Datasets以及DataLoader
2018-11-18 17:10:55
562
0
0
lee-romantic
**(一)自定义dataset** **什么是Datasets:** 在输入流水线中,我们看到准备数据的代码是这么写的`data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)`。`datasets.CIFAR10`就是一个`Datasets`子类,`data`是这个类的一个实例。 **为什么要定义Datasets**: PyTorch提供了一个工具函数`torch.utils.data.DataLoader`。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。 **如何自定义Datasets** 下面是一个自定义Datasets的框架 ``` class CustomDataset(data.Dataset):#需要继承data.Dataset,也可以继承object,自定义dataset最主要的是需要重定义__getitem__和__len__ def __init__(self): # TODO # 1. Initialize file path or list of file names. pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #这里需要注意的是,第一步:read one data,是一个data pass def __len__(self): # You should change 0 to the total size of your dataset. return 0 ``` 更具体的一个例子: ``` # class Dataset(object): class Dataset(torch.utils.data.Dataset): # 与上面的方式均可以运行,因为最主要是需要有__getitem__() 与__len__() def __init__(self, x0, x1, label): self.size = label.shape[0] self.x0 = torch.from_numpy(x0) self.x1 = torch.from_numpy(x1) self.label = torch.from_numpy(label) def __getitem__(self, index): return (self.x0[index], self.x1[index], self.label[index]) def __len__(self): return self.size ``` **(二)pytorch加载数据** pytorch读取训练集需要使用到2个类: (1)`torch.utils.data.Dataset` (2)`torch.utils.data.DataLoader` `torch.utils.data`主要包括以下三个类: 1.**class torch.utils.data.Dataset** 作用: (1) 创建数据集,有`__getitem__(self, index)`函数来根据索引序号获取图片和标签, 有`__len__(self)`函数来获取数据集的长度. 其他的数据集类必须是`torch.utils.data.Dataset`的子类,比如说`torchvision.ImageFolder`. 2. **class torch.utils.data.sampler.Sampler(data_source)** 参数: `data_source (Dataset)` – `dataset to sample from` 作用: 创建一个采样器, `class torch.utils.data.sampler.Sampler`是所有的`Sampler`的基类, 其中,`__iter__(self)`函数来获取一个迭代器,对数据集中元素的索引进行迭代,`__len(self)__`方法返回迭代器中包含元素的长度. 3.**class torch.utils.data.DataLoader**(`dataset, batch_size=1, shuffle=False, `sampler=None, batch_sampler=None, `num_workers=0,` collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None) 参数: >* `dataset (Dataset)`: 加载数据的数据集 * `batch_size (int, optional)`: 每批加载多少个样本 * `shuffle (bool, optional)`: 设置为“真”时,在每个epoch对数据打乱.(默认:False) * `sampler (Sampler, optional)`: 定义从数据集中提取样本的策略,返回一个样本 * `batch_sampler (Sampler, optional)`: like sampler, but returns a batch of indices at a time 返回一批样本. 与atch_size, shuffle, sampler和 drop_last互斥. *` num_workers` (int, optional): 用于加载数据的子进程数。0表示数据将在主进程中加载。(默认:0) * collate_fn (callable, optional): 合并样本列表以形成一个 mini-batch. # callable可调用对象 * pin_memory (bool, optional): 如果为 True, 数据加载器会将张量复制到 CUDA 固定内存中,然后再返回它们. * drop_last (bool, optional): 设定为 True 如果数据集大小不能被批量大小整除的时候, 将丢掉最后一个不完整的batch,(默认:False). * timeout (numeric, optional): 如果为正值,则为从工作人员收集批次的超时值。应始终是非负的。(默认:0) * worker_init_fn (callable, optional): If not None, this will be called on each worker subproc *具体例子参考地址:* https://blog.csdn.net/u012436149/article/details/69061711 和 https://blog.csdn.net/tsq292978891/article/details/79414512
上一篇:
MFC中check_box的控制方法
下一篇:
np中where()用法
0
赞
562 人读过
新浪微博
微信
腾讯微博
QQ空间
人人网
提交评论
立即登录
, 发表评论.
没有帐号?
立即注册
0
条评论
More...
文档导航
没有帐号? 立即注册