当前位置: 首页 > news >正文

pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print('w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n# l.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)#     break

相关文章:

pytorch 实现线性回归(深度学习)

一 查看原始函数 初始化 %matplotlib inline import random import torch from d2l import torch as d2l 1.1 生成原始数据 def synthetic_data(w, b, num_examples):x torch.normal(0, 1, (num_examples, len(w)))y torch.matmul(x, w) bprint(x:, x)print(y:, y)y tor…...

[Doris] Doris的安装和部署 (二)

文章目录 1.安装要求1.1 Linux操作系统要求1.2 软件需求1.3 注意事项1.4 内部端口 2.集群部署2.1 操作系统安装要求2.2 下载安装包2.3 解压2.4 配置FE2.5 配置BE2.6 添加BE2.7 FE 扩容和缩容2.8 Doris 集群群起脚本 3.图形化 1.安装要求 1.1 Linux操作系统要求 1.2 软件需求 1…...

【QT+QGIS跨平台编译】之三十五:【cairo+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、cairo介绍二、文件下载三、文件分析四、pro文件五、编译实践一、cairo介绍 Cairo是一个功能强大的开源2D图形库,它提供了一套跨平台的API,用于绘制矢量图形和文本。Cairo支持多种输出目标,包括屏幕、图像文件、PDF、SVG等。 Cairo的设计目标是简单易用、高效…...

MySQL(基础)

第01章_数据库概述 1. 为什么要使用数据库 持久化(persistence):把数据保存到可掉电式存储设备中以供之后使用。大多数情况下,特别是企业级应用,数据持久化意味着将内存中的数据保存到硬盘上加以”固化”,而持久化的实现过程大多…...

STM32F1 - 中断系统

Interrupt 1> 硬件框图2> NVIC 中断管理3> EXTI 中断管理3.1> EXTI与NVIC3.2> EXTI内部框图 4> 外部中断实验4.1> 实验概述4.2> 程序设计 5> 中断向量表6> 总结 1> 硬件框图 NVIC:Nested Vectored Interrupt Controller【嵌套向量…...

【Linux系统化学习】缓冲区

目录 缓冲区 一个样例 现象解释 缓冲区存在的位置 缓冲区 在刚开始学习C语言的时候我们就听过缓冲区这个名词,很是晦涩难懂;在Linux下进程退出时也包含缓冲区,因此缓冲区到底是什么?有什么作用? 让我们先从一个小…...

基于BP算法的SAR成像matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 BP算法的基本原理 4.2 BP算法的优点与局限性 5.完整工程文件 1.课题概述 基于BP算法的SAR成像。合成孔径雷达(SAR)是一种高分辨率的雷达系统,能够在各种天气和光…...

【C++ STL】你真的了解string吗?浅谈string的底层实现

文章目录 底层结构概述扩容机制浅拷贝与深拷贝插入和删除的效率浅谈VS和g的优化总结 底层结构概述 string可以帮助我们很好地管理字符串,但是你真的了解她吗?事实上,string的设计是非常复杂的,拥有上百个接口,但最常用…...

17.3.1.3 灰度

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 灰度的算法主要有以下三种: 1、最大值法: 原图像:颜色值color(R,G,B&a…...

基于CAS操作的atomic原子类型

在上一节的卖票程序中&#xff0c;我们讲解了如何在多线程中保证临界资源的正确访问——使用互斥锁&#xff0c;即 lock_guard<mutex> lock(mtx); count;lock_guard<mutex> lock(mtx); count--; 从汇编角度解释线程间互斥-mutex互斥锁与lock_guard的使用-CSDN博客…...

Rust HashMap详解及单词统计示例

在Rust中&#xff0c;HashMap是一种非常有用的数据结构&#xff0c;用于存储键值对。本文将深入介绍HashMap的特性&#xff0c;以及通过一个单词统计的例子展示其用法。 HashMap简介 HashMap是Rust标准库提供的用于存储键值对的数据结构。它允许通过键快速查找对应的值&#…...

命令执行讲解和函数

命令执行漏洞简介 命令执行漏洞产生原因 应用未对用户输入做严格得检查过滤&#xff0c;导致用户输入得参数被当成命令来执行 命令执行漏洞的危害 1.继承Web服务程序的权限去执行系统命会或读写文件 2.反弹shell&#xff0c;获得目标服务器的权限 3.进一步内网渗透 远程代…...

外包实在是太坑了,划水三年,感觉人都废了

先说一下自己的情况&#xff0c;专科生&#xff0c;19年通过校招进入杭州某个外包软件公司&#xff0c;干了接近3年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了3年的功…...

代码随想录算法训练营第19天

77. 组合 给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 class Solution:def combine(self, n: int, k: int) -> List[List[int]]:path []res []def dfs(n,k,index):if len(path) k:res.append(path[:])returnfor i in range(index,n1):…...

树莓派5 EEPROM引导加载程序恢复镜像

树莓派5不能正常启动&#xff0c;可以通过电源led灯的闪码来判断错误发生的大致情形。 LED警告闪码 如果树莓派由于某种原因无法启动&#xff0c;或者不得不关闭&#xff0c;在许多情况下&#xff0c;LED会闪烁特定的次数来指示发生了什么。LED会闪烁几次长闪烁&#xff0c;然…...

循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例

Markdown 进阶操作 查看更多学习笔记&#xff1a;GitHub&#xff1a;LoveEmiliaForever Mermaid官网 由于CSDN对某些Mermaid或Markdown语法不支持&#xff0c;因此我的某些效果展示使用图片进行 下面的笔记内容全部是我根据Mermaid官方文档学习的&#xff0c;因为是初学者所以…...

寒假作业2月6号

第五章 静态成员与友元 一、填空题 1、一个类的头文件如下所示&#xff0c;num初始化值为5&#xff0c;程序产生对象T&#xff0c;且修改num为10&#xff0c;并使用show()函数输出num的值10。 #include <iostream.h> class Test { private: static int num; publi…...

ChatGPT绘图指南:DALL.E3玩法大全(一)

一、 DALLE.3 模型介绍 1、什么是 DALLE.3 模型&#xff1f; DALLE-3模型&#xff0c;是一种由OpenAI研发的技术&#xff0c;它是一种先进的生成模型&#xff0c;可以将文字描述转化为清晰的图片。这种模型的名称"DALLE"实际上是"Deep Auto-regressive Latent …...

计算机服务器中了_locked勒索病毒怎么办?Encrypted勒索病毒解密数据恢复

随着网络技术的不断发展&#xff0c;数字化办公已经成为企业生产运营的根本&#xff0c;对于企业来说&#xff0c;数据至关重要&#xff0c;但网络威胁无处不在&#xff0c;近期&#xff0c;云天数据恢复中心接到很多企业的求助&#xff0c;企业的计算机服务器遭到了_locked勒索…...

VueCLI核心知识3:全局事件总线、消息订阅与发布

这两种方式都可以实现任意两个组件之间的通信 1 全局事件总线 1.安装全局事件总线 import Vue from vue import App from ./App.vueVue.config.productionTip false/* 1.第一种写法 */ // const Demo Vue.extend({}) // const d new Demo()// Vue.prototype.x d // 把Dem…...

阿里云ACP云计算备考笔记 (5)——弹性伸缩

目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI

前一阵子在百度 AI 开发者大会上&#xff0c;看到基于小智 AI DIY 玩具的演示&#xff0c;感觉有点意思&#xff0c;想着自己也来试试。 如果只是想烧录现成的固件&#xff0c;乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外&#xff0c;还提供了基于网页版的 ESP LA…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

LeetCode - 199. 二叉树的右视图

题目 199. 二叉树的右视图 - 力扣&#xff08;LeetCode&#xff09; 思路 右视图是指从树的右侧看&#xff0c;对于每一层&#xff0c;只能看到该层最右边的节点。实现思路是&#xff1a; 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文通过代码驱动的方式&#xff0c;系统讲解PyTorch核心概念和实战技巧&#xff0c;涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...