开发者

PyTorch实现线性回归详细过程

目录
  • 一、实现步骤
    • 1、准备数据
    • 2、设计模型
    • 3、构造损失函数和优化器
    • 4、训练过程
    • 5、结果展示
  • 二、参考文献

    一、实现步骤

    1、准备数据

    x_data = torch.tensor([[1.0],[2.0],[3.0]])
    y_data = torch.tensor([[2.0],[4.0],[6.0]])

    2、设计模型

    class LinearModel(torch.nn.Module):
      def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
       
      def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
       
    model = LinearModel() 

    3、构造损失函数和优化器

    criterion = torch.nn.MSELoss(reduction='sum')
    optimhttp://www.cppcns.comizer = torch.optim.SGD(model.parameters(),lr=0.01)

    4、训练过程

    epoch_list = []
    loss_list = []
    w_list = []
    b_list = []
    for epoch in range(1000):
      y_pred = model(x_data)      # 计算预测值
      loss = criterion(y_pred, y_data) # 计算损失
      print(epoch,loss)
     
      epoch_list.append(epoch)
      loss_list.append(loss.data.item())
      w_list.append(model.linear.weight.item())
      b_list.append(model.linear.bias.item())
     
      optimizer.zero_grad()  # 梯度归零
      loss.backward()     # 反向传播
      optimizer.step()    # 更新

    5、结果展示

    展示最终的权重和偏置:

    # 输出权重和偏置
    print('w = ',model.linear.weight.item())
    print('b = ',model.linear.bias.item())

    结果为:

    w =  1.9998501539230347

    b =  0.0003405189490877092

    模型测试:

    # 测试模型
    x_test = torch.tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ',y_test.data)
    
    y_pred = tensor([[7.9997]])

    分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:

    # 二维曲线图
    plt.phttp://www.cppcns.comlot(epoch_list,loss_list,'b')
    plt编程客栈.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
    
    # 三维散点图
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(w_list,b_list,loss_list,c='r')
    #设置坐MfJdEM标轴
    ax.set_xlabel('weight')
    ax.set_ylabel('bias')
    ax.set_zlabel('loss')
    plt.show()

    结果如下图所示:

    PyTorch实现线性回归详细过程

    PyTorch实现线性回归详细过程

     到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

    二、参考文献

    • [1] https://MfJdEMwww.bilibili.com/video/BV1Y7411d7Ys?p=5

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜