当前位置: 首页 > news >正文

001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):for param in params:param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

相关文章:

001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R10002 …...

Relations Prediction for Knowledge Graph Completion using Large Language Models

文章目录 题目摘要简介相关工作方法论实验结论局限性未来工作 题目 使用大型语言模型进行知识图谱补全的关系预测 论文地址:https://arxiv.org/pdf/2405.02738 项目地址: https://github.com/yao8839836/kg-llm 摘要 知识图谱已被广泛用于以结构化格式表…...

2024年中国研究生数学建模竞赛D题思路代码分析——大数据驱动的地理综合问题

地理系统是自然、人文多要素综合作用的复杂巨系统[1-2],地理学家常用地理综合的方式对地理系统进行主导特征的表达[3]。如以三大阶梯概括中国的地形特征,以秦岭—淮河一线和其它地理区划的方式揭示中国气温、降水、植被、土壤及生态环境在水平和垂直方向…...

全国31省对外开放程度、经济发展水平、政府干预程度指标数据(2000-2022年)

旨在分析2000-2022年间中国31个省份的对外开放程度、经济发展水平和政府干预程度,探讨其背后的动因与影响。 2000年-2022年 全国31省对外开放程度、经济发展水平、政府干预程度指标数据https://download.csdn.net/download/2401_84585615/89478612 数据概览 对外…...

计算机网络传输层---课后综合题

线路:TCP报文下放到物理层传输。 TCP报文段中,“序号”长度为32bit,为了让序列号不会循环,则最多能传输2^32B的数据,则最多能传输:2^32/1500B个报文 结果: 吞吐率一个周期内传输的数据/周期时间…...

【homebrew安装】踩坑爬坑教程

homebrew官网,有安装教程提示,但是在实际安装时,由于待下载的包的尺寸过大,本地git缓存尺寸、超时时间的限制,会报如下错误: error: RPC failed; curl 92 HTTP/2 stream 5 was not closed cleanly&#xf…...

反游戏学(Reludology):概念、历史、现状与展望?(豆包AI版)

李升伟 以下是关于“反游戏学(Reludology):概念、历史、现状与展望”的综述: 一、概念 反游戏学(Reludology)是一个相对较新且不太常见的概念,目前尚未有统一明确的定义。一般来说&#xf…...

【C/C++语言系列】实现单例模式

1.单例模式概念 定义:单例模式是一种常见的设计模式,它可以保证系统中一个类只有一个实例,而且该实例易于外界访问(一个类一个对象,共享这个对象)。 条件: 只有1个对象易于外界访问共享这个对…...

A. Make All Equal

time limit per test 1 second memory limit per test 256 megabytes You are given a cyclic array a1,a2,…,ana1,a2,…,an. You can perform the following operation on aa at most n−1n−1 times: Let mm be the current size of aa, you can choose any two adjac…...

业务安全治理

业务安全治理 1.账号安全撞库账户盗用 2.爬虫与反爬虫3.API网关防护4.钓鱼与反制钓鱼发现钓鱼处置 5.大数据风控风控介绍 1.账号安全 撞库 撞库分为垂直撞库和水平撞库两种,垂直撞库是对一个账号使用多个不同的密码进行尝试,可以理解为暴力破解&#x…...

HelpLook VS GitBook,在线文档管理工具对比

在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…...

docker面经

docker面经在线链接 docker面经在线链接🔗: (https://h03yz7idw7.feishu.cn/wiki/N3CVwO3kMifLypkJqnic9wNynKh)...

Python 中的 Kombu 类库

Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…...

safepoint是什么?有什么用?

在JVM中,safepoint(安全点)是一个非常重要的概念,特别是在垃圾回收(GC)和其他需要暂停所有应用线程的操作中。 什么是safepoint Safepoint是JVM执行过程中一个特定的位置,在这个位置上&#x…...

axios相关知识点

一、基本概念 1、基于Promise:Axios通过Promise实现异步请求,避免了传统回调函数导致的“回调地狱”问题,使得代码更加清晰和易于维护。 2、跨平台:Axios既可以在浏览器中运行,也可以在Node.js环境中使用,为前后端开…...

LeetCode 面试经典150题 67.二进制求和

415.字符串相加 思路一模一样 题目:给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 eg: 输入a“1010” b“1011” 输出“10101” 思路:从右开始遍历两个字符串,因为右边是低位先运算。如果…...

Dell PowerEdge 网络恢复笔记

我有一台Dell的PowerEdge服务器,之前安装了Ubuntu 20 桌面版。突然有一天不能开机了。 故障排查 Disk Error 首先是看一下机器的正面,有一个非常小的液晶显示器,只能显示一排字。 上面显示Disk Error,然后看挂载的硬盘仓&#…...

Java面试——集合篇

1.Java中常用的容器有哪些? 容器主要包括 Collection 和 Map 两种,Collection 存储着对象的集合,而 Map 存储着键值对(两个对象)的映射表。 如图: 面试官追问:说说集合有哪些类及他们各自的区别和特点? S…...

算法【双向广搜】

双向广搜常见用途 1:小优化。bfs的剪枝策略,分两侧展开分支,哪侧数量少就从哪侧展开。 2:用于解决特征很明显的一类问题。特征:全量样本不允许递归完全展开,但是半量样本可以完全展开。过程:把…...

javascript检测数据类型的方法

1. typeof 运算符 typeof是一个用来检测变量的基本数据类型的运算符。它可以返回以下几种类型的字符串:“undefined”、“boolean”、“number”、“string”、“object”、“function” 和 “symbol”(ES6)。需要注意的是,对于 n…...

网络六边形受到攻击

大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数,对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

PHP和Node.js哪个更爽?

先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...

Nuxt.js 中的路由配置详解

Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版,柱状图PPT模版,线状图PPT模版,折线图PPT模版,饼状图PPT模版,雷达图PPT模版,树状图PPT模版 图表类系列各种样式PPT模版分享:图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

GO协程(Goroutine)问题总结

在使用Go语言来编写代码时,遇到的一些问题总结一下 [参考文档]:https://www.topgoer.com/%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B/goroutine.html 1. main()函数默认的Goroutine 场景再现: 今天在看到这个教程的时候,在自己的电…...