当前位置: 首页 > news >正文

001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):for param in params:param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

相关文章:

001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R10002 …...

Relations Prediction for Knowledge Graph Completion using Large Language Models

文章目录 题目摘要简介相关工作方法论实验结论局限性未来工作 题目 使用大型语言模型进行知识图谱补全的关系预测 论文地址:https://arxiv.org/pdf/2405.02738 项目地址: https://github.com/yao8839836/kg-llm 摘要 知识图谱已被广泛用于以结构化格式表…...

2024年中国研究生数学建模竞赛D题思路代码分析——大数据驱动的地理综合问题

地理系统是自然、人文多要素综合作用的复杂巨系统[1-2],地理学家常用地理综合的方式对地理系统进行主导特征的表达[3]。如以三大阶梯概括中国的地形特征,以秦岭—淮河一线和其它地理区划的方式揭示中国气温、降水、植被、土壤及生态环境在水平和垂直方向…...

全国31省对外开放程度、经济发展水平、政府干预程度指标数据(2000-2022年)

旨在分析2000-2022年间中国31个省份的对外开放程度、经济发展水平和政府干预程度,探讨其背后的动因与影响。 2000年-2022年 全国31省对外开放程度、经济发展水平、政府干预程度指标数据https://download.csdn.net/download/2401_84585615/89478612 数据概览 对外…...

计算机网络传输层---课后综合题

线路:TCP报文下放到物理层传输。 TCP报文段中,“序号”长度为32bit,为了让序列号不会循环,则最多能传输2^32B的数据,则最多能传输:2^32/1500B个报文 结果: 吞吐率一个周期内传输的数据/周期时间…...

【homebrew安装】踩坑爬坑教程

homebrew官网,有安装教程提示,但是在实际安装时,由于待下载的包的尺寸过大,本地git缓存尺寸、超时时间的限制,会报如下错误: error: RPC failed; curl 92 HTTP/2 stream 5 was not closed cleanly&#xf…...

反游戏学(Reludology):概念、历史、现状与展望?(豆包AI版)

李升伟 以下是关于“反游戏学(Reludology):概念、历史、现状与展望”的综述: 一、概念 反游戏学(Reludology)是一个相对较新且不太常见的概念,目前尚未有统一明确的定义。一般来说&#xf…...

【C/C++语言系列】实现单例模式

1.单例模式概念 定义:单例模式是一种常见的设计模式,它可以保证系统中一个类只有一个实例,而且该实例易于外界访问(一个类一个对象,共享这个对象)。 条件: 只有1个对象易于外界访问共享这个对…...

A. Make All Equal

time limit per test 1 second memory limit per test 256 megabytes You are given a cyclic array a1,a2,…,ana1,a2,…,an. You can perform the following operation on aa at most n−1n−1 times: Let mm be the current size of aa, you can choose any two adjac…...

业务安全治理

业务安全治理 1.账号安全撞库账户盗用 2.爬虫与反爬虫3.API网关防护4.钓鱼与反制钓鱼发现钓鱼处置 5.大数据风控风控介绍 1.账号安全 撞库 撞库分为垂直撞库和水平撞库两种,垂直撞库是对一个账号使用多个不同的密码进行尝试,可以理解为暴力破解&#x…...

HelpLook VS GitBook,在线文档管理工具对比

在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…...

docker面经

docker面经在线链接 docker面经在线链接🔗: (https://h03yz7idw7.feishu.cn/wiki/N3CVwO3kMifLypkJqnic9wNynKh)...

Python 中的 Kombu 类库

Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…...

safepoint是什么?有什么用?

在JVM中,safepoint(安全点)是一个非常重要的概念,特别是在垃圾回收(GC)和其他需要暂停所有应用线程的操作中。 什么是safepoint Safepoint是JVM执行过程中一个特定的位置,在这个位置上&#x…...

axios相关知识点

一、基本概念 1、基于Promise:Axios通过Promise实现异步请求,避免了传统回调函数导致的“回调地狱”问题,使得代码更加清晰和易于维护。 2、跨平台:Axios既可以在浏览器中运行,也可以在Node.js环境中使用,为前后端开…...

LeetCode 面试经典150题 67.二进制求和

415.字符串相加 思路一模一样 题目:给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 eg: 输入a“1010” b“1011” 输出“10101” 思路:从右开始遍历两个字符串,因为右边是低位先运算。如果…...

Dell PowerEdge 网络恢复笔记

我有一台Dell的PowerEdge服务器,之前安装了Ubuntu 20 桌面版。突然有一天不能开机了。 故障排查 Disk Error 首先是看一下机器的正面,有一个非常小的液晶显示器,只能显示一排字。 上面显示Disk Error,然后看挂载的硬盘仓&#…...

Java面试——集合篇

1.Java中常用的容器有哪些? 容器主要包括 Collection 和 Map 两种,Collection 存储着对象的集合,而 Map 存储着键值对(两个对象)的映射表。 如图: 面试官追问:说说集合有哪些类及他们各自的区别和特点? S…...

算法【双向广搜】

双向广搜常见用途 1:小优化。bfs的剪枝策略,分两侧展开分支,哪侧数量少就从哪侧展开。 2:用于解决特征很明显的一类问题。特征:全量样本不允许递归完全展开,但是半量样本可以完全展开。过程:把…...

javascript检测数据类型的方法

1. typeof 运算符 typeof是一个用来检测变量的基本数据类型的运算符。它可以返回以下几种类型的字符串:“undefined”、“boolean”、“number”、“string”、“object”、“function” 和 “symbol”(ES6)。需要注意的是,对于 n…...

生信初学者教程(五):R语言基础

文章目录 数据类型整型逻辑型字符型日期型数值型复杂数数据结构向量矩阵数组列表因子数据框ts特殊值缺失值 (NA)无穷大 (Inf)非数字 (NaN)安装R包学习材料R语言是一种用于统计计算和图形展示的编程语言和软件环境,广泛应用于数据分析、统计建模和数据可视化。1991年:R语言的最…...

深度学习计算

一、层和块 块可以描述单个层、多个层组成的组件或整个模型。 通过定义块,组装块,可以实现复杂的神经网络。 一个块可以由多个class组成。 其实就是 自己定义神经网络net,自己定义层的顺序和具体的init、 forward函数。 层和块的顺序由sequen…...

Hexo博客私有部署Twikoo评论系统并迁移评论记录(自定义邮件回复模板)

部署 之前一直使用的artalk,现在想改用Twikoo,采用私有部署的方式。 私有部署 (Docker) 端口可以根据实际情况进行修改 docker run --name twikoo -e TWIKOO_THROTTLE1000 -p 8100:8100 -v ${PWD}/data:/app/data -e TWIKOO_PORT8100 -d imaegoo/twi…...

Vue.js 与 Flask/Django 后端配合:构建现代 Web 应用的最佳实践

Vue.js 与 Flask/Django 后端配合:构建现代 Web 应用的最佳实践 在现代 Web 开发中,前后端分离的架构已经成为主流。Vue.js 作为一个渐进式 JavaScript 框架,因其灵活性和易用性而广受欢迎。而 Flask 和 Django 则是 Python 生态中两个非常流…...

【笔记】自动驾驶预测与决策规划_Part3_路径与轨迹规划

文章目录 0. 前言1. 基于搜索的路径规划1.1 A* 算法1.2 Hybrid A* 算法 2. 基于采样的路径规划2.1 Frent Frame方法2.2 Cartesian →Frent 1D ( x , y ) (x, y) (x,y) —> ( s , l ) (s, l) (s,l)2.3 Cartesian →Frent 3D2.4 贝尔曼Bellman最优性原理2.5 高速轨迹采样——…...

Shiro-721—漏洞分析(CVE-2019-12422)

文章目录 Padding Oracle Attack 原理PKCS5填充怎么爆破攻击 漏洞原理源码分析漏洞复现 本文基于shiro550漏洞基础上分析,建议先看上期内容: https://blog.csdn.net/weixin_60521036/article/details/142373353 Padding Oracle Attack 原理 网上看了很多…...

【Python语言初识(一)】

一、python简史 1.1、python的历史 1989年圣诞节:Guido von Rossum开始写Python语言的编译器。1991年2月:第一个Python编译器(同时也是解释器)诞生,它是用C语言实现的(后面),可以调…...

Python 中的方法解析顺序(MRO)

在 Python 中,MRO(Method Resolution Order,方法解析顺序)是指类继承体系中,Python 如何确定在调用方法时的解析顺序。MRO 决定了在多继承环境下,Python 如何寻找方法或属性,即它会根据一定规则…...

MySQL表的内外连接

内连接 其实就是from 两个表 把笛卡尔积的表 再用where 进行条件筛选 ——之前我们写的多表查询就是内连接 基本格式: 外链接 没有向内连接那样真的把两个表连接形式一个表显示,而只是建立关系 外连接分为左链接和右链接 左链接 联合查询时候&#…...

系统架构设计师:软件架构的演化和维护

简简单单 Online zuozuo: 简简单单 Online zuozuo 简简单单 Online zuozuo 简简单单 Online zuozuo 简简单单 Online zuozuo :本心、输入输出、结果 简简单单 Online zuozuo : 文章目录 系统架构设计师:软件架构的演化和维护前言软件架构演化的重要性面向对象的软件架构演…...