加载数据集(Dataset and Dataloader)

dataset主要是用于构造数据集(支持索引),dataloader可以拿出一个mini-batch供我们快速使用。

一:上一节知识

一下是我们上一节提到的糖尿病数据集,其中在提到数据加载的时候,我们没有使用mini-batch的方法。

通常梯度下降有几种选择:

①Batch:选择全部数据传入,优点就是可以提升计算速度

②随机梯度下降:选择一个样本传入,优点是可以克服鞍点的问题

③mini-Batch:结合上述两种方法的优点

对于使用mini-batch需要明白的概念:Epoch,Batch-size,Iterations

Epoch:所有训练样本都经过一次的forward和一次的backward,称完成一次的epoch

Batch-size:指每次训练时所用的样本数量(forward,backward,updata)

Iterations:一共执行了多少个Batch-size,就是内层循环执行的次数

二:dataload

在做小批量的训练时,需要确定几个参数。

batch-size:每次训练时所用的样本数量

为了保证数据集的随机性,引入了shuffle=True 表示每次生成的mini-batch里的数据集都是打乱的。

得到数据集,数据集需要满足支持索引,知道长度。对其进行shuffle打乱顺序,接着进行分组loader,设置batch_size=2,所以每组有2个数值。

代码如下:

import torch
from torch.utils.data import Dataset #Dataset抽象类,只能继承使用
from torch.utils.data import DataLoader#用于加载数据,入batch_size,shuffle等,直接实例化使用

class DiabetesDataset(Dataset):#继承Dataset
    def __init__(self):
        pass
    
    def __getitem__(self,index):#实例化后可以支持下标操作,通过index把数据拿出来
        pass
    
    def __len__(self):#得到数据的条数
        pass

dataset = DiabetesDataset()
#dataset传数据集  batch_size小批量大小 shuffle是否打乱 num_works并行(此处为两个进程并行)
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)

使用多进程在后面的train_loader时,可能会报错,需要将后面使用到train_loader的语句封装起来。后面会有详细描述。

三:Dataset的实现

我们看到其实上面的代码关于初始化,索引,长度的函数都还没写,下面关于这三个部分进行详解:

代码:

import torch
import numpy as np
from torch.utils.data import Dataset #Dataset抽象类,只能继承使用
from torch.utils.data import DataLoader#用于加载数据,入batch_size,shuffle等,直接实例化使用

class DiabetesDataset(Dataset):#继承Dataset
    def __init__(self,filepath):#filepath是文件地址
        xy = np.loadtxt(filepath, delimiter=',',dtype=np.float32)#糖尿病人的相关因素
        self.len = xy.shape[0] #我们知道xy是一共矩阵nx9的矩阵,shape[0]得到整个数据集的个数
        self.x_data = torch.from_numpy(xy[:,:-1])#所有行,从第一列开始到最后一列(最后一列不要)
        self.y_data = torch.from_numpy(xy[:,[-1]])#所有行,最后一列([]是为了保证最后一列是一个矩阵,否则就是一个向量)
    
    def __getitem__(self,index):#实例化后可以支持下标操作,通过index把数据拿出来
        return self.x_data[index],self.y_data[index]
    
    def __len__(self):#得到数据的条数
        return self.len

dataset = DiabetesDataset('./dataset/diabetes.csv.gz')
#dataset传数据集  batch_size小批量大小 shuffle是否打乱 num_works并行(此处为两个进程并行)
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)

 

for epoch in range(100):#全部的所有数据运行100遍
    for i,data in enumerate(train_loader,0):#enumerate为了获取当前是第几次迭代 i得到的是index的数据,data对应的是(x,y)
        #1 数据准备
        inputs, labels = data #将x,y放入inputs和labels loader会自动的将data内的数变成张量
        #print (input,labels)
        #2 forward
        y_pred= model(inputs)
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())
        #3 backward
        optimizer.zero_grad()
        loss.backward()
        #4 Updata
        optimizer.step()

其他部分和上一讲相同,整体的步骤如下:

四:练习 

Titanic - Machine Learning from Disaster | Kaggle 网站使用他给的数据进行联系。

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-20 13:34:02       169 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-20 13:34:02       185 阅读
  3. 在Django里面运行非项目文件

    2024-07-20 13:34:02       155 阅读
  4. Python语言-面向对象

    2024-07-20 13:34:02       169 阅读

热门阅读

  1. windows上安装Apache

    2024-07-20 13:34:02       27 阅读
  2. 信息查询_社工

    2024-07-20 13:34:02       26 阅读
  3. Clickhouse 物化视图-optimize无效

    2024-07-20 13:34:02       34 阅读
  4. 07.16_111期_linux_网络通信

    2024-07-20 13:34:02       30 阅读
  5. 我为什么要使用Vim编辑器?

    2024-07-20 13:34:02       25 阅读
  6. 微服务概念篇-服务提供者/服务消费者

    2024-07-20 13:34:02       30 阅读
  7. 后端配置了相关字段后的前端跨域处理

    2024-07-20 13:34:02       32 阅读
  8. IP地址:由电脑还是网线决定?

    2024-07-20 13:34:02       29 阅读
  9. 【AI工具基础】—B树(B-tree)

    2024-07-20 13:34:02       30 阅读