当前位置: 首页 > news >正文

pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print('w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n# l.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)#     break

相关文章:

pytorch 实现线性回归(深度学习)

一 查看原始函数 初始化 %matplotlib inline import random import torch from d2l import torch as d2l 1.1 生成原始数据 def synthetic_data(w, b, num_examples):x torch.normal(0, 1, (num_examples, len(w)))y torch.matmul(x, w) bprint(x:, x)print(y:, y)y tor…...

[Doris] Doris的安装和部署 (二)

文章目录 1.安装要求1.1 Linux操作系统要求1.2 软件需求1.3 注意事项1.4 内部端口 2.集群部署2.1 操作系统安装要求2.2 下载安装包2.3 解压2.4 配置FE2.5 配置BE2.6 添加BE2.7 FE 扩容和缩容2.8 Doris 集群群起脚本 3.图形化 1.安装要求 1.1 Linux操作系统要求 1.2 软件需求 1…...

【QT+QGIS跨平台编译】之三十五:【cairo+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、cairo介绍二、文件下载三、文件分析四、pro文件五、编译实践一、cairo介绍 Cairo是一个功能强大的开源2D图形库,它提供了一套跨平台的API,用于绘制矢量图形和文本。Cairo支持多种输出目标,包括屏幕、图像文件、PDF、SVG等。 Cairo的设计目标是简单易用、高效…...

MySQL(基础)

第01章_数据库概述 1. 为什么要使用数据库 持久化(persistence):把数据保存到可掉电式存储设备中以供之后使用。大多数情况下,特别是企业级应用,数据持久化意味着将内存中的数据保存到硬盘上加以”固化”,而持久化的实现过程大多…...

STM32F1 - 中断系统

Interrupt 1> 硬件框图2> NVIC 中断管理3> EXTI 中断管理3.1> EXTI与NVIC3.2> EXTI内部框图 4> 外部中断实验4.1> 实验概述4.2> 程序设计 5> 中断向量表6> 总结 1> 硬件框图 NVIC:Nested Vectored Interrupt Controller【嵌套向量…...

【Linux系统化学习】缓冲区

目录 缓冲区 一个样例 现象解释 缓冲区存在的位置 缓冲区 在刚开始学习C语言的时候我们就听过缓冲区这个名词,很是晦涩难懂;在Linux下进程退出时也包含缓冲区,因此缓冲区到底是什么?有什么作用? 让我们先从一个小…...

基于BP算法的SAR成像matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 BP算法的基本原理 4.2 BP算法的优点与局限性 5.完整工程文件 1.课题概述 基于BP算法的SAR成像。合成孔径雷达(SAR)是一种高分辨率的雷达系统,能够在各种天气和光…...

【C++ STL】你真的了解string吗?浅谈string的底层实现

文章目录 底层结构概述扩容机制浅拷贝与深拷贝插入和删除的效率浅谈VS和g的优化总结 底层结构概述 string可以帮助我们很好地管理字符串,但是你真的了解她吗?事实上,string的设计是非常复杂的,拥有上百个接口,但最常用…...

17.3.1.3 灰度

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 灰度的算法主要有以下三种: 1、最大值法: 原图像:颜色值color(R,G,B&a…...

基于CAS操作的atomic原子类型

在上一节的卖票程序中&#xff0c;我们讲解了如何在多线程中保证临界资源的正确访问——使用互斥锁&#xff0c;即 lock_guard<mutex> lock(mtx); count;lock_guard<mutex> lock(mtx); count--; 从汇编角度解释线程间互斥-mutex互斥锁与lock_guard的使用-CSDN博客…...

Rust HashMap详解及单词统计示例

在Rust中&#xff0c;HashMap是一种非常有用的数据结构&#xff0c;用于存储键值对。本文将深入介绍HashMap的特性&#xff0c;以及通过一个单词统计的例子展示其用法。 HashMap简介 HashMap是Rust标准库提供的用于存储键值对的数据结构。它允许通过键快速查找对应的值&#…...

命令执行讲解和函数

命令执行漏洞简介 命令执行漏洞产生原因 应用未对用户输入做严格得检查过滤&#xff0c;导致用户输入得参数被当成命令来执行 命令执行漏洞的危害 1.继承Web服务程序的权限去执行系统命会或读写文件 2.反弹shell&#xff0c;获得目标服务器的权限 3.进一步内网渗透 远程代…...

外包实在是太坑了,划水三年,感觉人都废了

先说一下自己的情况&#xff0c;专科生&#xff0c;19年通过校招进入杭州某个外包软件公司&#xff0c;干了接近3年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了3年的功…...

代码随想录算法训练营第19天

77. 组合 给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 class Solution:def combine(self, n: int, k: int) -> List[List[int]]:path []res []def dfs(n,k,index):if len(path) k:res.append(path[:])returnfor i in range(index,n1):…...

树莓派5 EEPROM引导加载程序恢复镜像

树莓派5不能正常启动&#xff0c;可以通过电源led灯的闪码来判断错误发生的大致情形。 LED警告闪码 如果树莓派由于某种原因无法启动&#xff0c;或者不得不关闭&#xff0c;在许多情况下&#xff0c;LED会闪烁特定的次数来指示发生了什么。LED会闪烁几次长闪烁&#xff0c;然…...

循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例

Markdown 进阶操作 查看更多学习笔记&#xff1a;GitHub&#xff1a;LoveEmiliaForever Mermaid官网 由于CSDN对某些Mermaid或Markdown语法不支持&#xff0c;因此我的某些效果展示使用图片进行 下面的笔记内容全部是我根据Mermaid官方文档学习的&#xff0c;因为是初学者所以…...

寒假作业2月6号

第五章 静态成员与友元 一、填空题 1、一个类的头文件如下所示&#xff0c;num初始化值为5&#xff0c;程序产生对象T&#xff0c;且修改num为10&#xff0c;并使用show()函数输出num的值10。 #include <iostream.h> class Test { private: static int num; publi…...

ChatGPT绘图指南:DALL.E3玩法大全(一)

一、 DALLE.3 模型介绍 1、什么是 DALLE.3 模型&#xff1f; DALLE-3模型&#xff0c;是一种由OpenAI研发的技术&#xff0c;它是一种先进的生成模型&#xff0c;可以将文字描述转化为清晰的图片。这种模型的名称"DALLE"实际上是"Deep Auto-regressive Latent …...

计算机服务器中了_locked勒索病毒怎么办?Encrypted勒索病毒解密数据恢复

随着网络技术的不断发展&#xff0c;数字化办公已经成为企业生产运营的根本&#xff0c;对于企业来说&#xff0c;数据至关重要&#xff0c;但网络威胁无处不在&#xff0c;近期&#xff0c;云天数据恢复中心接到很多企业的求助&#xff0c;企业的计算机服务器遭到了_locked勒索…...

VueCLI核心知识3:全局事件总线、消息订阅与发布

这两种方式都可以实现任意两个组件之间的通信 1 全局事件总线 1.安装全局事件总线 import Vue from vue import App from ./App.vueVue.config.productionTip false/* 1.第一种写法 */ // const Demo Vue.extend({}) // const d new Demo()// Vue.prototype.x d // 把Dem…...

2021-03-15 iview一些问题

1.iview 在使用tree组件时&#xff0c;发现没有set类的方法&#xff0c;只有get&#xff0c;那么要改变tree值&#xff0c;只能遍历treeData&#xff0c;递归修改treeData的checked&#xff0c;发现无法更改&#xff0c;原因在于check模式下&#xff0c;子元素的勾选状态跟父节…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

让AI看见世界:MCP协议与服务器的工作原理

让AI看见世界&#xff1a;MCP协议与服务器的工作原理 MCP&#xff08;Model Context Protocol&#xff09;是一种创新的通信协议&#xff0c;旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天&#xff0c;MCP正成为连接AI与现实世界的重要桥梁。…...

如何在网页里填写 PDF 表格?

有时候&#xff0c;你可能希望用户能在你的网站上填写 PDF 表单。然而&#xff0c;这件事并不简单&#xff0c;因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件&#xff0c;但原生并不支持编辑或填写它们。更糟的是&#xff0c;如果你想收集表单数据&#xff…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中&#xff0c;我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道&#xff0c;它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

React从基础入门到高级实战:React 实战项目 - 项目五:微前端与模块化架构

React 实战项目&#xff1a;微前端与模块化架构 欢迎来到 React 开发教程专栏 的第 30 篇&#xff01;在前 29 篇文章中&#xff0c;我们从 React 的基础概念逐步深入到高级技巧&#xff0c;涵盖了组件设计、状态管理、路由配置、性能优化和企业级应用等核心内容。这一次&…...

ThreadLocal 源码

ThreadLocal 源码 此类提供线程局部变量。这些变量不同于它们的普通对应物&#xff0c;因为每个访问一个线程局部变量的线程&#xff08;通过其 get 或 set 方法&#xff09;都有自己独立初始化的变量副本。ThreadLocal 实例通常是类中的私有静态字段&#xff0c;这些类希望将…...

2.2.2 ASPICE的需求分析

ASPICE的需求分析是汽车软件开发过程中至关重要的一环&#xff0c;它涉及到对需求进行详细分析、验证和确认&#xff0c;以确保软件产品能够满足客户和用户的需求。在ASPICE中&#xff0c;需求分析的关键步骤包括&#xff1a; 需求细化&#xff1a;将从需求收集阶段获得的高层需…...

Tauri2学习笔记

教程地址&#xff1a;https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引&#xff1a;https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多&#xff0c;我按照Tauri1的教程来学习&…...