001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归
0. 背景介绍
我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt
1. 生成数据集合(待拟合)
使用python生成待拟合的数据
num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels = w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])
2.数据的分批量处理
def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)
3. 模型构建及训练
3.1 定义模型:
def linreg(X, w, b):return torch.mm(X,w)+b
3.2 定义损失函数
def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2
3.3 定义优化算法
def sgd(params , lr ,batch_size):for param in params:param.data -= lr * param.grad / batch_size
3.4 模型训练
# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))
epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053
基于pytorch的线性模型的实现
- 相关数据和初始化与上面构建相同
- 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net)
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
- 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
- 定义损失函数
loss = nn.MSELoss()
5.定义优化算法
import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
- 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )
epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066
相关文章:
001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R10002 …...
Relations Prediction for Knowledge Graph Completion using Large Language Models
文章目录 题目摘要简介相关工作方法论实验结论局限性未来工作 题目 使用大型语言模型进行知识图谱补全的关系预测 论文地址:https://arxiv.org/pdf/2405.02738 项目地址: https://github.com/yao8839836/kg-llm 摘要 知识图谱已被广泛用于以结构化格式表…...
2024年中国研究生数学建模竞赛D题思路代码分析——大数据驱动的地理综合问题
地理系统是自然、人文多要素综合作用的复杂巨系统[1-2],地理学家常用地理综合的方式对地理系统进行主导特征的表达[3]。如以三大阶梯概括中国的地形特征,以秦岭—淮河一线和其它地理区划的方式揭示中国气温、降水、植被、土壤及生态环境在水平和垂直方向…...
全国31省对外开放程度、经济发展水平、政府干预程度指标数据(2000-2022年)
旨在分析2000-2022年间中国31个省份的对外开放程度、经济发展水平和政府干预程度,探讨其背后的动因与影响。 2000年-2022年 全国31省对外开放程度、经济发展水平、政府干预程度指标数据https://download.csdn.net/download/2401_84585615/89478612 数据概览 对外…...
计算机网络传输层---课后综合题
线路:TCP报文下放到物理层传输。 TCP报文段中,“序号”长度为32bit,为了让序列号不会循环,则最多能传输2^32B的数据,则最多能传输:2^32/1500B个报文 结果: 吞吐率一个周期内传输的数据/周期时间…...
【homebrew安装】踩坑爬坑教程
homebrew官网,有安装教程提示,但是在实际安装时,由于待下载的包的尺寸过大,本地git缓存尺寸、超时时间的限制,会报如下错误: error: RPC failed; curl 92 HTTP/2 stream 5 was not closed cleanly…...
反游戏学(Reludology):概念、历史、现状与展望?(豆包AI版)
李升伟 以下是关于“反游戏学(Reludology):概念、历史、现状与展望”的综述: 一、概念 反游戏学(Reludology)是一个相对较新且不太常见的概念,目前尚未有统一明确的定义。一般来说…...
【C/C++语言系列】实现单例模式
1.单例模式概念 定义:单例模式是一种常见的设计模式,它可以保证系统中一个类只有一个实例,而且该实例易于外界访问(一个类一个对象,共享这个对象)。 条件: 只有1个对象易于外界访问共享这个对…...
A. Make All Equal
time limit per test 1 second memory limit per test 256 megabytes You are given a cyclic array a1,a2,…,ana1,a2,…,an. You can perform the following operation on aa at most n−1n−1 times: Let mm be the current size of aa, you can choose any two adjac…...
业务安全治理
业务安全治理 1.账号安全撞库账户盗用 2.爬虫与反爬虫3.API网关防护4.钓鱼与反制钓鱼发现钓鱼处置 5.大数据风控风控介绍 1.账号安全 撞库 撞库分为垂直撞库和水平撞库两种,垂直撞库是对一个账号使用多个不同的密码进行尝试,可以理解为暴力破解&#x…...
HelpLook VS GitBook,在线文档管理工具对比
在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…...
docker面经
docker面经在线链接 docker面经在线链接🔗: (https://h03yz7idw7.feishu.cn/wiki/N3CVwO3kMifLypkJqnic9wNynKh)...
Python 中的 Kombu 类库
Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…...
safepoint是什么?有什么用?
在JVM中,safepoint(安全点)是一个非常重要的概念,特别是在垃圾回收(GC)和其他需要暂停所有应用线程的操作中。 什么是safepoint Safepoint是JVM执行过程中一个特定的位置,在这个位置上&#x…...
axios相关知识点
一、基本概念 1、基于Promise:Axios通过Promise实现异步请求,避免了传统回调函数导致的“回调地狱”问题,使得代码更加清晰和易于维护。 2、跨平台:Axios既可以在浏览器中运行,也可以在Node.js环境中使用,为前后端开…...
LeetCode 面试经典150题 67.二进制求和
415.字符串相加 思路一模一样 题目:给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 eg: 输入a“1010” b“1011” 输出“10101” 思路:从右开始遍历两个字符串,因为右边是低位先运算。如果…...
Dell PowerEdge 网络恢复笔记
我有一台Dell的PowerEdge服务器,之前安装了Ubuntu 20 桌面版。突然有一天不能开机了。 故障排查 Disk Error 首先是看一下机器的正面,有一个非常小的液晶显示器,只能显示一排字。 上面显示Disk Error,然后看挂载的硬盘仓&#…...
Java面试——集合篇
1.Java中常用的容器有哪些? 容器主要包括 Collection 和 Map 两种,Collection 存储着对象的集合,而 Map 存储着键值对(两个对象)的映射表。 如图: 面试官追问:说说集合有哪些类及他们各自的区别和特点? S…...
算法【双向广搜】
双向广搜常见用途 1:小优化。bfs的剪枝策略,分两侧展开分支,哪侧数量少就从哪侧展开。 2:用于解决特征很明显的一类问题。特征:全量样本不允许递归完全展开,但是半量样本可以完全展开。过程:把…...
javascript检测数据类型的方法
1. typeof 运算符 typeof是一个用来检测变量的基本数据类型的运算符。它可以返回以下几种类型的字符串:“undefined”、“boolean”、“number”、“string”、“object”、“function” 和 “symbol”(ES6)。需要注意的是,对于 n…...
网络六边形受到攻击
大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...
谷歌浏览器插件
项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...
Android Wi-Fi 连接失败日志分析
1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...
Java 语言特性(面试系列2)
一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...
RocketMQ延迟消息机制
两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数,对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后…...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
线程与协程
1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...
Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...
图表类系列各种样式PPT模版分享
图标图表系列PPT模版,柱状图PPT模版,线状图PPT模版,折线图PPT模版,饼状图PPT模版,雷达图PPT模版,树状图PPT模版 图表类系列各种样式PPT模版分享:图表系列PPT模板https://pan.quark.cn/s/20d40aa…...
GO协程(Goroutine)问题总结
在使用Go语言来编写代码时,遇到的一些问题总结一下 [参考文档]:https://www.topgoer.com/%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B/goroutine.html 1. main()函数默认的Goroutine 场景再现: 今天在看到这个教程的时候,在自己的电…...
