001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归
0. 背景介绍
我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt
1. 生成数据集合(待拟合)
使用python生成待拟合的数据
num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels = w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])
2.数据的分批量处理
def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)
3. 模型构建及训练
3.1 定义模型:
def linreg(X, w, b):return torch.mm(X,w)+b
3.2 定义损失函数
def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2
3.3 定义优化算法
def sgd(params , lr ,batch_size):for param in params:param.data -= lr * param.grad / batch_size
3.4 模型训练
# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))
epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053
基于pytorch的线性模型的实现
- 相关数据和初始化与上面构建相同
- 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net)
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
- 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
- 定义损失函数
loss = nn.MSELoss()
5.定义优化算法
import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
- 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )
epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066
相关文章:

001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R10002 …...

Relations Prediction for Knowledge Graph Completion using Large Language Models
文章目录 题目摘要简介相关工作方法论实验结论局限性未来工作 题目 使用大型语言模型进行知识图谱补全的关系预测 论文地址:https://arxiv.org/pdf/2405.02738 项目地址: https://github.com/yao8839836/kg-llm 摘要 知识图谱已被广泛用于以结构化格式表…...
2024年中国研究生数学建模竞赛D题思路代码分析——大数据驱动的地理综合问题
地理系统是自然、人文多要素综合作用的复杂巨系统[1-2],地理学家常用地理综合的方式对地理系统进行主导特征的表达[3]。如以三大阶梯概括中国的地形特征,以秦岭—淮河一线和其它地理区划的方式揭示中国气温、降水、植被、土壤及生态环境在水平和垂直方向…...

全国31省对外开放程度、经济发展水平、政府干预程度指标数据(2000-2022年)
旨在分析2000-2022年间中国31个省份的对外开放程度、经济发展水平和政府干预程度,探讨其背后的动因与影响。 2000年-2022年 全国31省对外开放程度、经济发展水平、政府干预程度指标数据https://download.csdn.net/download/2401_84585615/89478612 数据概览 对外…...

计算机网络传输层---课后综合题
线路:TCP报文下放到物理层传输。 TCP报文段中,“序号”长度为32bit,为了让序列号不会循环,则最多能传输2^32B的数据,则最多能传输:2^32/1500B个报文 结果: 吞吐率一个周期内传输的数据/周期时间…...

【homebrew安装】踩坑爬坑教程
homebrew官网,有安装教程提示,但是在实际安装时,由于待下载的包的尺寸过大,本地git缓存尺寸、超时时间的限制,会报如下错误: error: RPC failed; curl 92 HTTP/2 stream 5 was not closed cleanly…...
反游戏学(Reludology):概念、历史、现状与展望?(豆包AI版)
李升伟 以下是关于“反游戏学(Reludology):概念、历史、现状与展望”的综述: 一、概念 反游戏学(Reludology)是一个相对较新且不太常见的概念,目前尚未有统一明确的定义。一般来说…...
【C/C++语言系列】实现单例模式
1.单例模式概念 定义:单例模式是一种常见的设计模式,它可以保证系统中一个类只有一个实例,而且该实例易于外界访问(一个类一个对象,共享这个对象)。 条件: 只有1个对象易于外界访问共享这个对…...
A. Make All Equal
time limit per test 1 second memory limit per test 256 megabytes You are given a cyclic array a1,a2,…,ana1,a2,…,an. You can perform the following operation on aa at most n−1n−1 times: Let mm be the current size of aa, you can choose any two adjac…...

业务安全治理
业务安全治理 1.账号安全撞库账户盗用 2.爬虫与反爬虫3.API网关防护4.钓鱼与反制钓鱼发现钓鱼处置 5.大数据风控风控介绍 1.账号安全 撞库 撞库分为垂直撞库和水平撞库两种,垂直撞库是对一个账号使用多个不同的密码进行尝试,可以理解为暴力破解&#x…...

HelpLook VS GitBook,在线文档管理工具对比
在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…...
docker面经
docker面经在线链接 docker面经在线链接🔗: (https://h03yz7idw7.feishu.cn/wiki/N3CVwO3kMifLypkJqnic9wNynKh)...

Python 中的 Kombu 类库
Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…...
safepoint是什么?有什么用?
在JVM中,safepoint(安全点)是一个非常重要的概念,特别是在垃圾回收(GC)和其他需要暂停所有应用线程的操作中。 什么是safepoint Safepoint是JVM执行过程中一个特定的位置,在这个位置上&#x…...
axios相关知识点
一、基本概念 1、基于Promise:Axios通过Promise实现异步请求,避免了传统回调函数导致的“回调地狱”问题,使得代码更加清晰和易于维护。 2、跨平台:Axios既可以在浏览器中运行,也可以在Node.js环境中使用,为前后端开…...
LeetCode 面试经典150题 67.二进制求和
415.字符串相加 思路一模一样 题目:给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 eg: 输入a“1010” b“1011” 输出“10101” 思路:从右开始遍历两个字符串,因为右边是低位先运算。如果…...

Dell PowerEdge 网络恢复笔记
我有一台Dell的PowerEdge服务器,之前安装了Ubuntu 20 桌面版。突然有一天不能开机了。 故障排查 Disk Error 首先是看一下机器的正面,有一个非常小的液晶显示器,只能显示一排字。 上面显示Disk Error,然后看挂载的硬盘仓&#…...

Java面试——集合篇
1.Java中常用的容器有哪些? 容器主要包括 Collection 和 Map 两种,Collection 存储着对象的集合,而 Map 存储着键值对(两个对象)的映射表。 如图: 面试官追问:说说集合有哪些类及他们各自的区别和特点? S…...
算法【双向广搜】
双向广搜常见用途 1:小优化。bfs的剪枝策略,分两侧展开分支,哪侧数量少就从哪侧展开。 2:用于解决特征很明显的一类问题。特征:全量样本不允许递归完全展开,但是半量样本可以完全展开。过程:把…...
javascript检测数据类型的方法
1. typeof 运算符 typeof是一个用来检测变量的基本数据类型的运算符。它可以返回以下几种类型的字符串:“undefined”、“boolean”、“number”、“string”、“object”、“function” 和 “symbol”(ES6)。需要注意的是,对于 n…...

Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
Python爬虫实战:研究feedparser库相关技术
1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...

基于PHP的连锁酒店管理系统
有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发,数据库mysql,前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

Linux-进程间的通信
1、IPC: Inter Process Communication(进程间通信): 由于每个进程在操作系统中有独立的地址空间,它们不能像线程那样直接访问彼此的内存,所以必须通过某种方式进行通信。 常见的 IPC 方式包括&#…...

数据分析六部曲?
引言 上一章我们说到了数据分析六部曲,何谓六部曲呢? 其实啊,数据分析没那么难,只要掌握了下面这六个步骤,也就是数据分析六部曲,就算你是个啥都不懂的小白,也能慢慢上手做数据分析啦。 第一…...

如何做好一份技术文档?从规划到实践的完整指南
如何做好一份技术文档?从规划到实践的完整指南 🌟 嗨,我是IRpickstars! 🌌 总有一行代码,能点亮万千星辰。 🔍 在技术的宇宙中,我愿做永不停歇的探索者。 ✨ 用代码丈量世界&…...

WinUI3开发_使用mica效果
简介 Mica(云母)是Windows10/11上的一种现代化效果,是Windows10/11上所使用的Fluent Design(设计语言)里的一个效果,Windows10/11上所使用的Fluent Design皆旨在于打造一个人类、通用和真正感觉与 Windows 一样的设计。 WinUI3就是Windows10/11上的一个…...
Flask和Django,你怎么选?
Flask 和 Django 是 Python 两大最流行的 Web 框架,但它们的设计哲学、目标和适用场景有显著区别。以下是详细的对比: 核心区别:哲学与定位 Django: 定位: "全栈式" Web 框架。奉行"开箱即用"的理念。 哲学: "包含…...