pytorch 实现线性回归(深度学习)
一 查看原始函数
初始化
%matplotlib inline
import random
import torch
from d2l import torch as d2l
1.1 生成原始数据
def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape) # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重
随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1。
w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b
![]()
二 执行训练
查看训练过程中的 参数变化:
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print('w:', w, 'b:', b) # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)


三 测试梯度更新
初始化数据
%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape) # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b
3.1 测试更新
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
# param -= lr * param.grad / batch_size
# param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b) # l:', l, '\nl.sum().backward() # 计算更新梯度sgd([w, b], lr, batch_size)
使用 l.sum().backward() # 计算更新梯度:

不使用更新时:
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
# param -= lr * param.grad / batch_size
# param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b) # l:', l, '\n# l.sum().backward() # 计算更新梯度sgd([w, b], lr, batch_size)# break

相关文章:
pytorch 实现线性回归(深度学习)
一 查看原始函数 初始化 %matplotlib inline import random import torch from d2l import torch as d2l 1.1 生成原始数据 def synthetic_data(w, b, num_examples):x torch.normal(0, 1, (num_examples, len(w)))y torch.matmul(x, w) bprint(x:, x)print(y:, y)y tor…...
[Doris] Doris的安装和部署 (二)
文章目录 1.安装要求1.1 Linux操作系统要求1.2 软件需求1.3 注意事项1.4 内部端口 2.集群部署2.1 操作系统安装要求2.2 下载安装包2.3 解压2.4 配置FE2.5 配置BE2.6 添加BE2.7 FE 扩容和缩容2.8 Doris 集群群起脚本 3.图形化 1.安装要求 1.1 Linux操作系统要求 1.2 软件需求 1…...
【QT+QGIS跨平台编译】之三十五:【cairo+Qt跨平台编译】(一套代码、一套框架,跨平台编译)
文章目录 一、cairo介绍二、文件下载三、文件分析四、pro文件五、编译实践一、cairo介绍 Cairo是一个功能强大的开源2D图形库,它提供了一套跨平台的API,用于绘制矢量图形和文本。Cairo支持多种输出目标,包括屏幕、图像文件、PDF、SVG等。 Cairo的设计目标是简单易用、高效…...
MySQL(基础)
第01章_数据库概述 1. 为什么要使用数据库 持久化(persistence):把数据保存到可掉电式存储设备中以供之后使用。大多数情况下,特别是企业级应用,数据持久化意味着将内存中的数据保存到硬盘上加以”固化”,而持久化的实现过程大多…...
STM32F1 - 中断系统
Interrupt 1> 硬件框图2> NVIC 中断管理3> EXTI 中断管理3.1> EXTI与NVIC3.2> EXTI内部框图 4> 外部中断实验4.1> 实验概述4.2> 程序设计 5> 中断向量表6> 总结 1> 硬件框图 NVIC:Nested Vectored Interrupt Controller【嵌套向量…...
【Linux系统化学习】缓冲区
目录 缓冲区 一个样例 现象解释 缓冲区存在的位置 缓冲区 在刚开始学习C语言的时候我们就听过缓冲区这个名词,很是晦涩难懂;在Linux下进程退出时也包含缓冲区,因此缓冲区到底是什么?有什么作用? 让我们先从一个小…...
基于BP算法的SAR成像matlab仿真
目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 BP算法的基本原理 4.2 BP算法的优点与局限性 5.完整工程文件 1.课题概述 基于BP算法的SAR成像。合成孔径雷达(SAR)是一种高分辨率的雷达系统,能够在各种天气和光…...
【C++ STL】你真的了解string吗?浅谈string的底层实现
文章目录 底层结构概述扩容机制浅拷贝与深拷贝插入和删除的效率浅谈VS和g的优化总结 底层结构概述 string可以帮助我们很好地管理字符串,但是你真的了解她吗?事实上,string的设计是非常复杂的,拥有上百个接口,但最常用…...
17.3.1.3 灰度
版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 灰度的算法主要有以下三种: 1、最大值法: 原图像:颜色值color(R,G,B&a…...
基于CAS操作的atomic原子类型
在上一节的卖票程序中,我们讲解了如何在多线程中保证临界资源的正确访问——使用互斥锁,即 lock_guard<mutex> lock(mtx); count;lock_guard<mutex> lock(mtx); count--; 从汇编角度解释线程间互斥-mutex互斥锁与lock_guard的使用-CSDN博客…...
Rust HashMap详解及单词统计示例
在Rust中,HashMap是一种非常有用的数据结构,用于存储键值对。本文将深入介绍HashMap的特性,以及通过一个单词统计的例子展示其用法。 HashMap简介 HashMap是Rust标准库提供的用于存储键值对的数据结构。它允许通过键快速查找对应的值&#…...
命令执行讲解和函数
命令执行漏洞简介 命令执行漏洞产生原因 应用未对用户输入做严格得检查过滤,导致用户输入得参数被当成命令来执行 命令执行漏洞的危害 1.继承Web服务程序的权限去执行系统命会或读写文件 2.反弹shell,获得目标服务器的权限 3.进一步内网渗透 远程代…...
外包实在是太坑了,划水三年,感觉人都废了
先说一下自己的情况,专科生,19年通过校招进入杭州某个外包软件公司,干了接近3年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了3年的功…...
代码随想录算法训练营第19天
77. 组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 class Solution:def combine(self, n: int, k: int) -> List[List[int]]:path []res []def dfs(n,k,index):if len(path) k:res.append(path[:])returnfor i in range(index,n1):…...
树莓派5 EEPROM引导加载程序恢复镜像
树莓派5不能正常启动,可以通过电源led灯的闪码来判断错误发生的大致情形。 LED警告闪码 如果树莓派由于某种原因无法启动,或者不得不关闭,在许多情况下,LED会闪烁特定的次数来指示发生了什么。LED会闪烁几次长闪烁,然…...
循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例
Markdown 进阶操作 查看更多学习笔记:GitHub:LoveEmiliaForever Mermaid官网 由于CSDN对某些Mermaid或Markdown语法不支持,因此我的某些效果展示使用图片进行 下面的笔记内容全部是我根据Mermaid官方文档学习的,因为是初学者所以…...
寒假作业2月6号
第五章 静态成员与友元 一、填空题 1、一个类的头文件如下所示,num初始化值为5,程序产生对象T,且修改num为10,并使用show()函数输出num的值10。 #include <iostream.h> class Test { private: static int num; publi…...
ChatGPT绘图指南:DALL.E3玩法大全(一)
一、 DALLE.3 模型介绍 1、什么是 DALLE.3 模型? DALLE-3模型,是一种由OpenAI研发的技术,它是一种先进的生成模型,可以将文字描述转化为清晰的图片。这种模型的名称"DALLE"实际上是"Deep Auto-regressive Latent …...
计算机服务器中了_locked勒索病毒怎么办?Encrypted勒索病毒解密数据恢复
随着网络技术的不断发展,数字化办公已经成为企业生产运营的根本,对于企业来说,数据至关重要,但网络威胁无处不在,近期,云天数据恢复中心接到很多企业的求助,企业的计算机服务器遭到了_locked勒索…...
VueCLI核心知识3:全局事件总线、消息订阅与发布
这两种方式都可以实现任意两个组件之间的通信 1 全局事件总线 1.安装全局事件总线 import Vue from vue import App from ./App.vueVue.config.productionTip false/* 1.第一种写法 */ // const Demo Vue.extend({}) // const d new Demo()// Vue.prototype.x d // 把Dem…...
2021-03-15 iview一些问题
1.iview 在使用tree组件时,发现没有set类的方法,只有get,那么要改变tree值,只能遍历treeData,递归修改treeData的checked,发现无法更改,原因在于check模式下,子元素的勾选状态跟父节…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
让AI看见世界:MCP协议与服务器的工作原理
让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...
如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...
Go 语言并发编程基础:无缓冲与有缓冲通道
在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好࿰…...
React从基础入门到高级实战:React 实战项目 - 项目五:微前端与模块化架构
React 实战项目:微前端与模块化架构 欢迎来到 React 开发教程专栏 的第 30 篇!在前 29 篇文章中,我们从 React 的基础概念逐步深入到高级技巧,涵盖了组件设计、状态管理、路由配置、性能优化和企业级应用等核心内容。这一次&…...
ThreadLocal 源码
ThreadLocal 源码 此类提供线程局部变量。这些变量不同于它们的普通对应物,因为每个访问一个线程局部变量的线程(通过其 get 或 set 方法)都有自己独立初始化的变量副本。ThreadLocal 实例通常是类中的私有静态字段,这些类希望将…...
2.2.2 ASPICE的需求分析
ASPICE的需求分析是汽车软件开发过程中至关重要的一环,它涉及到对需求进行详细分析、验证和确认,以确保软件产品能够满足客户和用户的需求。在ASPICE中,需求分析的关键步骤包括: 需求细化:将从需求收集阶段获得的高层需…...
Tauri2学习笔记
教程地址:https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引:https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多,我按照Tauri1的教程来学习&…...
